|
@@ -0,0 +1,1715 @@
|
|
|
|
+/******************************************************************************
|
|
|
|
+ * Copyright (c) 2024, Tri Dao.
|
|
|
|
+ ******************************************************************************/
|
|
|
|
+
|
|
|
|
+#pragma once
|
|
|
|
+
|
|
|
|
+#include <cute/tensor.hpp>
|
|
|
|
+
|
|
|
|
+#include <cutlass/cutlass.h>
|
|
|
|
+#include <cutlass/array.h>
|
|
|
|
+#include <cutlass/numeric_types.h>
|
|
|
|
+
|
|
|
|
+#include "block_info.h"
|
|
|
|
+#include "kernel_traits.h"
|
|
|
|
+#include "utils.h"
|
|
|
|
+#include "softmax.h"
|
|
|
|
+#include "mask.h"
|
|
|
|
+#include "dropout.h"
|
|
|
|
+#include "rotary.h"
|
|
|
|
+
|
|
|
|
+namespace flash {
|
|
|
|
+
|
|
|
|
+using namespace cute;
|
|
|
|
+
|
|
|
|
+template <typename Engine, typename Layout>
|
|
|
|
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout>& tensor,
|
|
|
|
+ const float softcap) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int i = 0; i < size(tensor); ++i) {
|
|
|
|
+ tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
+
|
|
|
|
+template <typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
|
|
|
|
+__forceinline__ __device__ auto get_lse_tile(
|
|
|
|
+ const Params& params, const int bidb, const int bidh, const int m_block,
|
|
|
|
+ const BlockInfo</*Varlen=*/!Is_even_MN>& binfo) {
|
|
|
|
+ // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) -
|
|
|
|
+ // this is non-variable seqlen path. Otherwise, when
|
|
|
|
+ // params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b)
|
|
|
|
+ // to account for seqlen_q <-> h swapping trick. Otherwise, it's written as
|
|
|
|
+ // (h, b, seqlen_q).
|
|
|
|
+ const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
|
|
|
|
+ auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
|
|
|
|
+ auto gmem_ptr_lse = make_gmem_ptr(
|
|
|
|
+ reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
|
|
|
|
+
|
|
|
|
+ auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q)
|
|
|
|
+ : make_shape(params.b, params.h, params.seqlen_q);
|
|
|
|
+ auto lse_stride =
|
|
|
|
+ params.seqlenq_ngroups_swapped
|
|
|
|
+ ? make_stride(1, params.seqlen_q * params.b, params.b)
|
|
|
|
+ : (params.unpadded_lse
|
|
|
|
+ ? make_stride(params.h * params.total_q, params.total_q, 1)
|
|
|
|
+ : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1));
|
|
|
|
+
|
|
|
|
+ auto lse_layout = make_layout(lse_shape, lse_stride);
|
|
|
|
+ Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
|
|
|
|
+ auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
|
|
|
|
+ return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
|
|
|
|
+ bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
|
|
|
|
+ bool Is_softcap, bool Return_softmax, typename Params>
|
|
|
|
+inline __device__ void compute_attn_1rowblock(const Params& params,
|
|
|
|
+ const int bidb, const int bidh,
|
|
|
|
+ const int m_block) {
|
|
|
|
+ using Element = typename Kernel_traits::Element;
|
|
|
|
+ using ElementAccum = typename Kernel_traits::ElementAccum;
|
|
|
|
+ using index_t = typename Kernel_traits::index_t;
|
|
|
|
+
|
|
|
|
+ // Shared memory.
|
|
|
|
+ extern __shared__ char smem_[];
|
|
|
|
+
|
|
|
|
+ // The thread index.
|
|
|
|
+ const int tidx = threadIdx.x;
|
|
|
|
+
|
|
|
|
+ constexpr int kBlockM = Kernel_traits::kBlockM;
|
|
|
|
+ constexpr int kBlockN = Kernel_traits::kBlockN;
|
|
|
|
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
|
|
|
+ constexpr int kNWarps = Kernel_traits::kNWarps;
|
|
|
|
+
|
|
|
|
+ auto seed_offset = at::cuda::philox::unpack(params.philox_args);
|
|
|
|
+ flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset),
|
|
|
|
+ params.p_dropout_in_uint8_t, bidb, bidh, tidx,
|
|
|
|
+ params.h);
|
|
|
|
+
|
|
|
|
+ // Save seed and offset for backward, before any early exiting. Otherwise the
|
|
|
|
+ // 0-th thread block might exit early and no one saves the rng states.
|
|
|
|
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
|
|
|
|
+ tidx == 0) {
|
|
|
|
+ params.rng_state[0] = std::get<0>(seed_offset);
|
|
|
|
+ params.rng_state[1] = std::get<1>(seed_offset);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
|
|
|
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
|
|
|
+
|
|
|
|
+ const int n_block_min =
|
|
|
|
+ !Is_local
|
|
|
|
+ ? 0
|
|
|
|
+ : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k -
|
|
|
|
+ binfo.actual_seqlen_q - params.window_size_left) /
|
|
|
|
+ kBlockN);
|
|
|
|
+ int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
|
|
|
+ if (Is_causal || Is_local) {
|
|
|
|
+ n_block_max = std::min(
|
|
|
|
+ n_block_max,
|
|
|
|
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k -
|
|
|
|
+ binfo.actual_seqlen_q + params.window_size_right,
|
|
|
|
+ kBlockN));
|
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
|
|
|
+ // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
|
|
|
+ // }
|
|
|
|
+ }
|
|
|
|
+ // We exit early and write 0 to gO and gLSE. This also covers the case where
|
|
|
|
+ // actual_seqlen_k == 0. Otherwise we might read OOB elements from gK and gV.
|
|
|
|
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
|
|
|
|
+ Tensor mO = make_tensor(
|
|
|
|
+ make_gmem_ptr(
|
|
|
|
+ reinterpret_cast<Element*>(params.o_ptr) +
|
|
|
|
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
|
|
|
|
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
|
|
|
|
+ make_stride(params.o_row_stride, params.o_head_stride, _1{}));
|
|
|
|
+ Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
|
|
|
|
+
|
|
|
|
+ Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(
|
|
|
|
+ params, bidb, bidh, m_block, binfo);
|
|
|
|
+
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
|
|
|
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
|
|
|
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
|
|
|
+ Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
|
|
|
+ clear(tOrO);
|
|
|
|
+ // Construct identity layout for sO
|
|
|
|
+ Tensor cO = make_identity_tensor(make_shape(
|
|
|
|
+ size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
|
+ // Repeat the partitioning with identity layouts
|
|
|
|
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
|
|
|
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
|
|
|
+ if (!Is_even_K) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tOpO); ++k) {
|
|
|
|
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
|
|
|
|
+ /*Clear_OOB_K=*/false>(
|
|
|
|
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM);
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int m = 0; m < size<1>(tOgO); ++m) {
|
|
|
|
+ const int row = get<0>(tOcO(0, m, 0));
|
|
|
|
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM &&
|
|
|
|
+ get<1>(tOcO(0, m, 0)) == 0) {
|
|
|
|
+ gLSE(row) = INFINITY;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+ // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max =
|
|
|
|
+ // %d\n", m_block, n_block_min, n_block_max); }
|
|
|
|
+
|
|
|
|
+ // We iterate over the blocks in reverse order. This is because the last block
|
|
|
|
+ // is the only one that needs masking when we read K and V from global memory.
|
|
|
|
+ // Moreover, iterating in reverse might save us 1 register (we just need
|
|
|
|
+ // n_block instead of both n_block and n_block_max).
|
|
|
|
+
|
|
|
|
+ const index_t row_offset_p =
|
|
|
|
+ ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) *
|
|
|
|
+ params.seqlen_k_rounded +
|
|
|
|
+ (n_block_max - 1) * kBlockN;
|
|
|
|
+
|
|
|
|
+ Tensor mQ =
|
|
|
|
+ make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) +
|
|
|
|
+ binfo.q_offset(params.q_batch_stride,
|
|
|
|
+ params.q_row_stride, bidb)),
|
|
|
|
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
|
|
|
|
+ make_stride(params.q_row_stride, params.q_head_stride, _1{}));
|
|
|
|
+ Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
|
|
|
|
+ Tensor mK =
|
|
|
|
+ make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) +
|
|
|
|
+ binfo.k_offset(params.k_batch_stride,
|
|
|
|
+ params.k_row_stride, bidb)),
|
|
|
|
+ make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
|
|
|
|
+ make_stride(params.k_row_stride, params.k_head_stride, _1{}));
|
|
|
|
+ Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
|
|
|
|
+ Tensor mV =
|
|
|
|
+ make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) +
|
|
|
|
+ binfo.k_offset(params.v_batch_stride,
|
|
|
|
+ params.v_row_stride, bidb)),
|
|
|
|
+ make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
|
|
|
|
+ make_stride(params.v_row_stride, params.v_head_stride, _1{}));
|
|
|
|
+ Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
|
|
|
|
+ Tensor gP = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.p_ptr) + row_offset_p),
|
|
|
|
+ Shape<Int<kBlockM>, Int<kBlockN>>{},
|
|
|
|
+ make_stride(params.seqlen_k_rounded, _1{}));
|
|
|
|
+
|
|
|
|
+ Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
|
|
|
|
+ typename Kernel_traits::SmemLayoutQ{});
|
|
|
|
+ // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
|
|
|
|
+ Tensor sK =
|
|
|
|
+ make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
|
|
|
|
+ typename Kernel_traits::SmemLayoutKV{});
|
|
|
|
+
|
|
|
|
+ Tensor sV =
|
|
|
|
+ make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
|
|
|
+ Tensor sVt =
|
|
|
|
+ make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
|
|
|
+ Tensor sVtNoSwizzle =
|
|
|
|
+ make_tensor(sV.data().get(),
|
|
|
|
+ typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
|
|
|
+
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
|
|
|
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
|
|
|
+
|
|
|
|
+ Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
|
|
|
+ Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
|
|
|
+ Tensor tKgK =
|
|
|
|
+ gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
|
|
|
|
+ Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
|
|
|
+ Tensor tVgV =
|
|
|
|
+ gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
|
|
|
|
+ Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
|
|
|
+
|
|
|
|
+ typename Kernel_traits::TiledMma tiled_mma;
|
|
|
|
+ auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
|
|
|
+ Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
|
|
|
+ Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
|
|
|
+ Tensor tOrVt =
|
|
|
|
+ thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
|
|
|
+
|
|
|
|
+ Tensor tSgS = thr_mma.partition_C(gP);
|
|
|
|
+
|
|
|
|
+ Tensor acc_o = partition_fragment_C(
|
|
|
|
+ tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
|
|
|
+
|
|
|
|
+ //
|
|
|
|
+ // Copy Atom retiling
|
|
|
|
+ //
|
|
|
|
+
|
|
|
|
+ auto smem_tiled_copy_Q =
|
|
|
|
+ make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
|
|
|
+ // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
|
|
|
|
+ Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
|
|
|
+ // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
|
|
|
|
+
|
|
|
|
+ auto smem_tiled_copy_K =
|
|
|
|
+ make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
|
|
|
+ Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
|
|
|
+
|
|
|
|
+ auto smem_tiled_copy_V = make_tiled_copy_B(
|
|
|
|
+ typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
|
|
|
+ Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
|
|
|
+
|
|
|
|
+ //
|
|
|
|
+ // PREDICATES
|
|
|
|
+ //
|
|
|
|
+
|
|
|
|
+ // // Allocate predicate tensors for m and n
|
|
|
|
+ // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)),
|
|
|
|
+ // Stride<_1,_0>{}); Tensor tKVpKV =
|
|
|
|
+ // make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)),
|
|
|
|
+ // Stride<_1,_0>{});
|
|
|
|
+
|
|
|
|
+ // Construct identity layout for sQ and sK
|
|
|
|
+ Tensor cQ = make_identity_tensor(
|
|
|
|
+ make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
|
+ Tensor cKV = make_identity_tensor(
|
|
|
|
+ make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
|
|
|
+ // Tensor tScQ = thr_mma.partition_A(cQ); //
|
|
|
|
+ // (MMA,MMA_M,MMA_K) if (cute::thread0()) {
|
|
|
|
+ // print(tScQ.layout()); printf("\n");
|
|
|
|
+ // for (int i = 0; i < size(tScQ); ++i) {
|
|
|
|
+ // printf("%d ", get<0>(tScQ(i)));
|
|
|
|
+ // }
|
|
|
|
+ // printf("\n");
|
|
|
|
+ // for (int i = 0; i < size(tScQ); ++i) {
|
|
|
|
+ // printf("%d ", get<1>(tScQ(i)));
|
|
|
|
+ // }
|
|
|
|
+ // printf("\n");
|
|
|
|
+ // }
|
|
|
|
+
|
|
|
|
+ // Repeat the partitioning with identity layouts
|
|
|
|
+ Tensor tQcQ = gmem_thr_copy_QKV.partition_S(
|
|
|
|
+ cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
|
|
|
+ Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(
|
|
|
|
+ cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
|
|
|
+
|
|
|
|
+ // Allocate predicate tensors for k
|
|
|
|
+ Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
|
|
|
+ Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
|
|
|
+
|
|
|
|
+ // Set predicates for k bounds
|
|
|
|
+ if (!Is_even_K) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tQpQ); ++k) {
|
|
|
|
+ tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tKVpKV); ++k) {
|
|
|
|
+ tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Prologue
|
|
|
|
+
|
|
|
|
+ // We don't need to clear the sQ smem tiles since we'll only write out the
|
|
|
|
+ // valid outputs
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ,
|
|
|
|
+ tQpQ,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM);
|
|
|
|
+ if (Kernel_traits::Is_Q_in_regs) {
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // // if (cute::thread(1, 0)) { print(tQsQ); }
|
|
|
|
+ // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element
|
|
|
|
+ // *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
|
|
|
|
+ // // if (cute::thread0()) { print(sQNoSwizzle); }
|
|
|
|
+
|
|
|
|
+ if (Kernel_traits::Share_Q_K_smem) {
|
|
|
|
+ flash::cp_async_wait<0>();
|
|
|
|
+ __syncthreads();
|
|
|
|
+ Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
|
|
|
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
|
|
|
+ __syncthreads();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ int n_block = n_block_max - 1;
|
|
|
|
+ // We don't need to clear the sK smem tiles since we'll mask out the scores
|
|
|
|
+ // anyway.
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K>(
|
|
|
|
+ gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN);
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
|
|
|
+ // __syncthreads();
|
|
|
|
+
|
|
|
|
+ if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
|
|
|
|
+ flash::cp_async_wait<1>();
|
|
|
|
+ __syncthreads();
|
|
|
|
+ Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
|
|
|
+ CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
|
|
|
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ clear(acc_o);
|
|
|
|
+
|
|
|
|
+ flash::Softmax<2 * size<1>(acc_o)> softmax;
|
|
|
|
+
|
|
|
|
+ const float alibi_slope =
|
|
|
|
+ !Has_alibi || params.alibi_slopes_ptr == nullptr
|
|
|
|
+ ? 0.0f
|
|
|
|
+ : reinterpret_cast<float*>(params.alibi_slopes_ptr)
|
|
|
|
+ [bidb * params.alibi_slopes_batch_stride + bidh] /
|
|
|
|
+ params.scale_softmax;
|
|
|
|
+ flash::Mask<Is_causal, Is_local, Has_alibi> mask(
|
|
|
|
+ binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left,
|
|
|
|
+ params.window_size_right, alibi_slope);
|
|
|
|
+
|
|
|
|
+ // For performance reason, we separate out two kinds of iterations:
|
|
|
|
+ // those that need masking on S, and those that don't.
|
|
|
|
+ // We need masking on S for the very last block when K and V has length not
|
|
|
|
+ // multiple of kBlockN. We also need masking on S if it's causal, for the last
|
|
|
|
+ // ceil_div(kBlockM, kBlockN) blocks. We will have at least 1 "masking"
|
|
|
|
+ // iteration.
|
|
|
|
+
|
|
|
|
+ // If not even_N, then seqlen_k might end in the middle of a block. In that
|
|
|
|
+ // case we need to mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
|
|
|
+ constexpr int n_masking_steps =
|
|
|
|
+ (!Is_causal && !Is_local)
|
|
|
|
+ ? 1
|
|
|
|
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN)
|
|
|
|
+ : cute::ceil_div(kBlockM, kBlockN) + 1);
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int masking_step = 0; masking_step < n_masking_steps;
|
|
|
|
+ ++masking_step, --n_block) {
|
|
|
|
+ Tensor acc_s = partition_fragment_C(
|
|
|
|
+ tiled_mma,
|
|
|
|
+ Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
|
|
|
+ clear(acc_s);
|
|
|
|
+ flash::cp_async_wait<0>();
|
|
|
|
+ __syncthreads();
|
|
|
|
+
|
|
|
|
+ // Advance gV
|
|
|
|
+ if (masking_step > 0) {
|
|
|
|
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(
|
|
|
|
+ gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
|
|
|
|
+ } else {
|
|
|
|
+ // Clear the smem tiles to account for predicated off loads
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
|
|
|
+ gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN);
|
|
|
|
+ }
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+
|
|
|
|
+ flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
|
|
|
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
|
|
|
|
+ smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
|
|
|
|
+ // if (cute::thread0()) { print(acc_s); }
|
|
|
|
+ if constexpr (Is_softcap) {
|
|
|
|
+ apply_softcap(acc_s, params.softcap);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ mask.template apply_mask<Is_causal, Is_even_MN>(
|
|
|
|
+ acc_s, n_block * kBlockN,
|
|
|
|
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
|
|
|
|
+
|
|
|
|
+ flash::cp_async_wait<0>();
|
|
|
|
+ __syncthreads();
|
|
|
|
+ if (n_block > n_block_min) {
|
|
|
|
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV,
|
|
|
|
+ tKgK(_, _, _, n_block - 1),
|
|
|
|
+ tKsK, tKVcKV, tKVpKV);
|
|
|
|
+ // This cp_async_fence needs to be in the if block, otherwise the
|
|
|
|
+ // synchronization isn't right and we get race conditions.
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // TODO: when we have key_padding_mask we'll need to Check_inf
|
|
|
|
+ masking_step == 0
|
|
|
|
+ ? softmax.template softmax_rescale_o<
|
|
|
|
+ /*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(
|
|
|
|
+ acc_s, acc_o, params.scale_softmax_log2)
|
|
|
|
+ : softmax.template softmax_rescale_o<
|
|
|
|
+ /*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(
|
|
|
|
+ acc_s, acc_o, params.scale_softmax_log2);
|
|
|
|
+
|
|
|
|
+ // Convert acc_s from fp32 to fp16/bf16
|
|
|
|
+ Tensor rP = flash::convert_type<Element>(acc_s);
|
|
|
|
+ int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
|
|
|
+ int block_col_idx = n_block * (kBlockN / 32);
|
|
|
|
+ if (Return_softmax) {
|
|
|
|
+ Tensor rP_drop = make_fragment_like(rP);
|
|
|
|
+ cute::copy(rP, rP_drop);
|
|
|
|
+ dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
|
|
|
+ rP_drop, block_row_idx, block_col_idx, kNWarps);
|
|
|
|
+ cute::copy(rP_drop, tSgS);
|
|
|
|
+ tSgS.data() = tSgS.data() + (-kBlockN);
|
|
|
|
+ }
|
|
|
|
+ if (Is_dropout) {
|
|
|
|
+ dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
|
|
|
+ // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
|
|
|
|
+ Tensor tOrP = make_tensor(
|
|
|
|
+ rP.data(),
|
|
|
|
+ flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
|
|
|
+ // if (cute::thread0()) { print(tOrP); }
|
|
|
|
+ flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
|
|
|
|
+ smem_thr_copy_V);
|
|
|
|
+ // if (cute::thread0()) { print(scores); }
|
|
|
|
+
|
|
|
|
+ // This check is at the end of the loop since we always have at least 1
|
|
|
|
+ // iteration
|
|
|
|
+ if (n_masking_steps > 1 && n_block <= n_block_min) {
|
|
|
|
+ --n_block;
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // These are the iterations where we don't need masking on S
|
|
|
|
+ for (; n_block >= n_block_min; --n_block) {
|
|
|
|
+ Tensor acc_s = partition_fragment_C(
|
|
|
|
+ tiled_mma,
|
|
|
|
+ Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
|
|
|
+ clear(acc_s);
|
|
|
|
+ flash::cp_async_wait<0>();
|
|
|
|
+ __syncthreads();
|
|
|
|
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(
|
|
|
|
+ gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+
|
|
|
|
+ flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
|
|
|
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
|
|
|
|
+ smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
|
|
|
|
+ if constexpr (Is_softcap) {
|
|
|
|
+ apply_softcap(acc_s, params.softcap);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ flash::cp_async_wait<0>();
|
|
|
|
+ __syncthreads();
|
|
|
|
+ if (n_block > n_block_min) {
|
|
|
|
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV,
|
|
|
|
+ tKgK(_, _, _, n_block - 1),
|
|
|
|
+ tKsK, tKVcKV, tKVpKV);
|
|
|
|
+ // This cp_async_fence needs to be in the if block, otherwise the
|
|
|
|
+ // synchronization isn't right and we get race conditions.
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ mask.template apply_mask</*Causal_mask=*/false>(
|
|
|
|
+ acc_s, n_block * kBlockN,
|
|
|
|
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
|
|
|
|
+
|
|
|
|
+ softmax
|
|
|
|
+ .template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(
|
|
|
|
+ acc_s, acc_o, params.scale_softmax_log2);
|
|
|
|
+
|
|
|
|
+ Tensor rP = flash::convert_type<Element>(acc_s);
|
|
|
|
+ int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
|
|
|
+ int block_col_idx = n_block * (kBlockN / 32);
|
|
|
|
+ if (Return_softmax) {
|
|
|
|
+ Tensor rP_drop = make_fragment_like(rP);
|
|
|
|
+ cute::copy(rP, rP_drop);
|
|
|
|
+ dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
|
|
|
+ rP_drop, block_row_idx, block_col_idx, kNWarps);
|
|
|
|
+ cute::copy(rP_drop, tSgS);
|
|
|
|
+ tSgS.data() = tSgS.data() + (-kBlockN);
|
|
|
|
+ }
|
|
|
|
+ if (Is_dropout) {
|
|
|
|
+ dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
|
|
|
+ // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
|
|
|
|
+ Tensor tOrP = make_tensor(
|
|
|
|
+ rP.data(),
|
|
|
|
+ flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
|
|
|
+ flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
|
|
|
|
+ smem_thr_copy_V);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Epilogue
|
|
|
|
+
|
|
|
|
+ Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(
|
|
|
|
+ acc_o, params.scale_softmax, params.rp_dropout);
|
|
|
|
+
|
|
|
|
+ // Convert acc_o from fp32 to fp16/bf16
|
|
|
|
+ Tensor rO = flash::convert_type<Element>(acc_o);
|
|
|
|
+ Tensor sO = make_tensor(
|
|
|
|
+ sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
|
|
|
+ // Partition sO to match the accumulator partitioning
|
|
|
|
+ auto smem_tiled_copy_O =
|
|
|
|
+ make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
|
|
|
|
+ Tensor taccOrO =
|
|
|
|
+ smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
|
|
|
+ Tensor taccOsO =
|
|
|
|
+ smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
|
|
|
+
|
|
|
|
+ // sO has the same size as sQ, so we don't need to sync here.
|
|
|
|
+ if (Kernel_traits::Share_Q_K_smem) {
|
|
|
|
+ __syncthreads();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
|
|
|
+
|
|
|
|
+ Tensor mO =
|
|
|
|
+ make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) +
|
|
|
|
+ binfo.q_offset(params.o_batch_stride,
|
|
|
|
+ params.o_row_stride, bidb)),
|
|
|
|
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
|
|
|
|
+ make_stride(params.o_row_stride, params.o_head_stride, _1{}));
|
|
|
|
+ Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
|
|
|
|
+ Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(
|
|
|
|
+ params, bidb, bidh, m_block, binfo);
|
|
|
|
+
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
|
|
|
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
|
|
|
+ Tensor tOsO =
|
|
|
|
+ gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
|
|
|
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
|
|
|
+
|
|
|
|
+ __syncthreads();
|
|
|
|
+
|
|
|
|
+ Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
|
|
|
+ cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
|
|
|
+
|
|
|
|
+ Tensor caccO = make_identity_tensor(
|
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
|
+ Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
|
|
|
|
+ static_assert(decltype(size<0>(taccOcO))::value == 4);
|
|
|
|
+ // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
|
|
|
|
+ Tensor taccOcO_row =
|
|
|
|
+ logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
|
|
|
|
+ CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
|
|
|
+ if (get<1>(taccOcO_row(0)) == 0) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int mi = 0; mi < size(lse); ++mi) {
|
|
|
|
+ const int row = get<0>(taccOcO_row(mi));
|
|
|
|
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
|
|
|
|
+ gLSE(row) = lse(mi);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Construct identity layout for sO
|
|
|
|
+ Tensor cO = make_identity_tensor(
|
|
|
|
+ make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
|
+ // Repeat the partitioning with identity layouts
|
|
|
|
+ Tensor tOcO =
|
|
|
|
+ gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
|
|
|
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
|
|
|
+ if (!Is_even_K) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tOpO); ++k) {
|
|
|
|
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
|
|
|
|
+ /*Clear_OOB_K=*/false>(gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
+
|
|
|
|
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
|
|
|
|
+ bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split,
|
|
|
|
+ bool Append_KV, typename Params>
|
|
|
|
+inline __device__ void compute_attn_1rowblock_splitkv(
|
|
|
|
+ const Params& params, const int bidb, const int bidh, const int m_block,
|
|
|
|
+ const int n_split_idx, const int num_n_splits) {
|
|
|
|
+ using Element = typename Kernel_traits::Element;
|
|
|
|
+ using ElementAccum = typename Kernel_traits::ElementAccum;
|
|
|
|
+ using index_t = typename Kernel_traits::index_t;
|
|
|
|
+
|
|
|
|
+ // Shared memory.
|
|
|
|
+ extern __shared__ char smem_[];
|
|
|
|
+
|
|
|
|
+ // The thread index.
|
|
|
|
+ const int tidx = threadIdx.x;
|
|
|
|
+
|
|
|
|
+ constexpr int kBlockM = Kernel_traits::kBlockM;
|
|
|
|
+ constexpr int kBlockN = Kernel_traits::kBlockN;
|
|
|
|
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
|
|
|
+ constexpr int kNWarps = Kernel_traits::kNWarps;
|
|
|
|
+
|
|
|
|
+ using GmemTiledCopyO =
|
|
|
|
+ std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO,
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyOaccum>;
|
|
|
|
+ using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
|
|
|
|
+
|
|
|
|
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
|
|
|
+ // printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d,
|
|
|
|
+ // actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative,
|
|
|
|
+ // binfo.seqlen_k_cache, binfo.actual_seqlen_k); } if (threadIdx.x == 0 &&
|
|
|
|
+ // blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p,
|
|
|
|
+ // seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache
|
|
|
|
+ // + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
|
|
|
|
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
|
|
|
+
|
|
|
|
+ const int n_blocks_per_split =
|
|
|
|
+ ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) /
|
|
|
|
+ num_n_splits;
|
|
|
|
+ const int n_block_min =
|
|
|
|
+ !Is_local ? n_split_idx * n_blocks_per_split
|
|
|
|
+ : std::max(n_split_idx * n_blocks_per_split,
|
|
|
|
+ (m_block * kBlockM + binfo.actual_seqlen_k -
|
|
|
|
+ binfo.actual_seqlen_q - params.window_size_left) /
|
|
|
|
+ kBlockN);
|
|
|
|
+ int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN),
|
|
|
|
+ (n_split_idx + 1) * n_blocks_per_split);
|
|
|
|
+ if (Is_causal || Is_local) {
|
|
|
|
+ n_block_max = std::min(
|
|
|
|
+ n_block_max,
|
|
|
|
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k -
|
|
|
|
+ binfo.actual_seqlen_q + params.window_size_right,
|
|
|
|
+ kBlockN));
|
|
|
|
+ }
|
|
|
|
+ if (n_block_min >=
|
|
|
|
+ n_block_max) { // This also covers the case where n_block_max <= 0
|
|
|
|
+ // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
|
|
|
|
+ // Otherwise we might read OOB elements from gK and gV,
|
|
|
|
+ // or get wrong results when we combine gOaccum from different blocks.
|
|
|
|
+ const index_t row_offset_o =
|
|
|
|
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) +
|
|
|
|
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
|
|
|
+ const index_t row_offset_oaccum =
|
|
|
|
+ (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
|
|
|
|
+ m_block * kBlockM) *
|
|
|
|
+ params.d_rounded;
|
|
|
|
+ const index_t row_offset_lseaccum =
|
|
|
|
+ ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
|
|
|
|
+ m_block * kBlockM;
|
|
|
|
+ Tensor gOaccum = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<ElementO*>(Split ? params.oaccum_ptr
|
|
|
|
+ : params.o_ptr) +
|
|
|
|
+ (Split ? row_offset_oaccum : row_offset_o)),
|
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
|
|
|
|
+ Tensor gLSEaccum = make_tensor(
|
|
|
|
+ make_gmem_ptr(
|
|
|
|
+ reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr
|
|
|
|
+ : params.softmax_lse_ptr) +
|
|
|
|
+ row_offset_lseaccum),
|
|
|
|
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
|
|
|
|
+
|
|
|
|
+ GmemTiledCopyO gmem_tiled_copy_Oaccum;
|
|
|
|
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
|
|
|
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
|
|
|
|
+ Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
|
|
|
|
+ clear(tOrOaccum);
|
|
|
|
+ // Construct identity layout for sO
|
|
|
|
+ Tensor cO = make_identity_tensor(make_shape(
|
|
|
|
+ size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
|
+ // Repeat the partitioning with identity layouts
|
|
|
|
+ Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
|
|
|
|
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
|
|
|
+ if (!Is_even_K) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tOpO); ++k) {
|
|
|
|
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
|
|
|
|
+ /*Clear_OOB_K=*/false>(
|
|
|
|
+ gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM);
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int m = 0; m < size<1>(tOgOaccum); ++m) {
|
|
|
|
+ const int row = get<0>(tOcO(0, m, 0));
|
|
|
|
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM &&
|
|
|
|
+ get<1>(tOcO(0, m, 0)) == 0) {
|
|
|
|
+ gLSEaccum(row) = Split ? -INFINITY : INFINITY;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // We iterate over the blocks in reverse order. This is because the last block
|
|
|
|
+ // is the only one that needs masking when we read K and V from global memory.
|
|
|
|
+ // Moreover, iterating in reverse might save us 1 register (we just need
|
|
|
|
+ // n_block instead of both n_block and n_block_max).
|
|
|
|
+
|
|
|
|
+ // We move K and V to the last block.
|
|
|
|
+ const int bidb_cache =
|
|
|
|
+ params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
|
|
|
|
+ const int* block_table =
|
|
|
|
+ params.block_table == nullptr
|
|
|
|
+ ? nullptr
|
|
|
|
+ : params.block_table + bidb * params.block_table_batch_stride;
|
|
|
|
+ const index_t row_offset_k =
|
|
|
|
+ block_table == nullptr
|
|
|
|
+ ? binfo.k_offset(params.k_batch_stride, params.k_row_stride,
|
|
|
|
+ bidb_cache) +
|
|
|
|
+ (n_block_max - 1) * kBlockN * params.k_row_stride +
|
|
|
|
+ (bidh / params.h_h_k_ratio) * params.k_head_stride
|
|
|
|
+ : (bidh / params.h_h_k_ratio) *
|
|
|
|
+ params.k_head_stride; // block addresses are later resolved
|
|
|
|
+ // per-thread
|
|
|
|
+
|
|
|
|
+ const index_t row_offset_v =
|
|
|
|
+ block_table == nullptr
|
|
|
|
+ ? binfo.k_offset(params.v_batch_stride, params.v_row_stride,
|
|
|
|
+ bidb_cache) +
|
|
|
|
+ (n_block_max - 1) * kBlockN * params.v_row_stride +
|
|
|
|
+ (bidh / params.h_h_k_ratio) * params.v_head_stride
|
|
|
|
+ : (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
|
|
|
+
|
|
|
|
+ Tensor mQ =
|
|
|
|
+ make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) +
|
|
|
|
+ binfo.q_offset(params.q_batch_stride,
|
|
|
|
+ params.q_row_stride, bidb)),
|
|
|
|
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
|
|
|
|
+ make_stride(params.q_row_stride, params.q_head_stride, _1{}));
|
|
|
|
+ Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
|
|
|
|
+ Tensor gK = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(params.k_row_stride, _1{}));
|
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr
|
|
|
|
+ // = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k,
|
|
|
|
+ // gK.data()); }
|
|
|
|
+ Tensor gV = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(params.v_row_stride, _1{}));
|
|
|
|
+ Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
|
|
|
|
+ typename Kernel_traits::SmemLayoutQ{});
|
|
|
|
+ Tensor sK =
|
|
|
|
+ make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
|
|
|
|
+ Tensor sV =
|
|
|
|
+ make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
|
|
|
+ Tensor sVt =
|
|
|
|
+ make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
|
|
|
+ Tensor sVtNoSwizzle =
|
|
|
|
+ make_tensor(sV.data().get(),
|
|
|
|
+ typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
|
|
|
+
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
|
|
|
|
+ auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
|
|
|
|
+ auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
|
|
|
|
+
|
|
|
|
+ Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
|
|
|
|
+ Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
|
|
|
|
+
|
|
|
|
+ Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
|
|
|
+ Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
|
|
|
|
+ Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
|
|
|
+ Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);
|
|
|
|
+
|
|
|
|
+ Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout()));
|
|
|
|
+ Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout()));
|
|
|
|
+ Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout()));
|
|
|
|
+ Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
|
|
|
|
+
|
|
|
|
+ if (block_table != nullptr) {
|
|
|
|
+ tKgK.data() =
|
|
|
|
+ gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
|
|
|
|
+ tidx, n_block_max, params.page_block_size, block_table,
|
|
|
|
+ params.k_batch_stride, params.k_row_stride);
|
|
|
|
+ tVgV.data() =
|
|
|
|
+ gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
|
|
|
|
+ tidx, n_block_max, params.page_block_size, block_table,
|
|
|
|
+ params.v_batch_stride, params.v_row_stride);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ typename Kernel_traits::TiledMma tiled_mma;
|
|
|
|
+ auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
|
|
|
+ Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
|
|
|
+ Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
|
|
|
+ Tensor tOrVt =
|
|
|
|
+ thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
|
|
|
+
|
|
|
|
+ Tensor acc_o = partition_fragment_C(
|
|
|
|
+ tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
|
|
|
+
|
|
|
|
+ //
|
|
|
|
+ // Copy Atom retiling
|
|
|
|
+ //
|
|
|
|
+
|
|
|
|
+ auto smem_tiled_copy_Q =
|
|
|
|
+ make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
|
|
|
+ Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
|
|
|
+
|
|
|
|
+ auto smem_tiled_copy_K =
|
|
|
|
+ make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
|
|
|
+ Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
|
|
|
+
|
|
|
|
+ auto smem_tiled_copy_V = make_tiled_copy_B(
|
|
|
|
+ typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
|
|
|
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
|
|
|
+ Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
|
|
|
+
|
|
|
|
+ // PREDICATES
|
|
|
|
+ //
|
|
|
|
+
|
|
|
|
+ // // Allocate predicate tensors for m and n
|
|
|
|
+ // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)),
|
|
|
|
+ // Stride<_1,_0>{}); Tensor tKVpKV =
|
|
|
|
+ // make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)),
|
|
|
|
+ // Stride<_1,_0>{});
|
|
|
|
+
|
|
|
|
+ // Construct identity layout for sQ and sK
|
|
|
|
+ Tensor cQ = make_identity_tensor(
|
|
|
|
+ make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
|
|
|
+ Tensor cKV = make_identity_tensor(
|
|
|
|
+ make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
|
|
|
+
|
|
|
|
+ // Repeat the partitioning with identity layouts
|
|
|
|
+ Tensor tQcQ =
|
|
|
|
+ gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
|
|
|
+ Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(
|
|
|
|
+ cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
|
|
|
+ Tensor tKVcKV =
|
|
|
|
+ make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
|
|
|
|
+
|
|
|
|
+ // Allocate predicate tensors for k
|
|
|
|
+ Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
|
|
|
+ Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
|
|
|
+
|
|
|
|
+ // Set predicates for k bounds
|
|
|
|
+ if (!Is_even_K) {
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tQpQ); ++k) {
|
|
|
|
+ tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+#pragma unroll
|
|
|
|
+ for (int k = 0; k < size(tKVpKV); ++k) {
|
|
|
|
+ tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Prologue
|
|
|
|
+
|
|
|
|
+ // Copy from Knew to K, optionally apply rotary embedding.
|
|
|
|
+ if constexpr (Append_KV) {
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
|
|
|
|
+ auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyRotcossinContPaged
|
|
|
|
+ gmem_tiled_copy_rotary_cont;
|
|
|
|
+ auto gmem_thr_copy_rotary_cont =
|
|
|
|
+ gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
|
|
|
|
+
|
|
|
|
+ // Even if we have MQA / GQA, all threadblocks responsible for the same KV
|
|
|
|
+ // head are writing to gmem. Technically it's a race condition, but they all
|
|
|
|
+ // write the same content anyway, and it's safe. We want to do this so that
|
|
|
|
+ // all threadblocks can proceed right after they finish writing the KV
|
|
|
|
+ // cache.
|
|
|
|
+ const index_t row_offset_cossin =
|
|
|
|
+ ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
|
|
|
|
+ Tensor gCos = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
|
|
|
|
+ make_stride(params.rotary_dim / 2, _1{}));
|
|
|
|
+ Tensor gSin = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
|
|
|
|
+ make_stride(params.rotary_dim / 2, _1{}));
|
|
|
|
+ Tensor gCosCont = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(params.rotary_dim / 2, _1{}));
|
|
|
|
+ Tensor gSinCont = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(params.rotary_dim / 2, _1{}));
|
|
|
|
+
|
|
|
|
+ Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
|
|
|
|
+ Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
|
|
|
|
+ Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
|
|
|
|
+ Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
|
|
|
|
+
|
|
|
|
+ Tensor tRgCos =
|
|
|
|
+ make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
|
|
|
|
+ Tensor tRgSin =
|
|
|
|
+ make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
|
|
|
|
+ Tensor tRgCosCont = make_tensor(
|
|
|
|
+ tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
|
|
|
|
+ Tensor tRgSinCont = make_tensor(
|
|
|
|
+ tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
|
|
|
|
+
|
|
|
|
+ // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p,
|
|
|
|
+ // tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr,
|
|
|
|
+ // gCos.data(), tRgCos.data(), params.rotary_dim); } if (cute::thread(8, 0))
|
|
|
|
+ // { print_tensor(gCos); } if (cute::thread(0, 0)) { print_tensor(tRgCos); }
|
|
|
|
+
|
|
|
|
+ const index_t row_offset_knew =
|
|
|
|
+ binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) +
|
|
|
|
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride +
|
|
|
|
+ (bidh / params.h_h_k_ratio) * params.knew_head_stride;
|
|
|
|
+ const index_t row_offset_vnew =
|
|
|
|
+ binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) +
|
|
|
|
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride +
|
|
|
|
+ (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
|
|
|
|
+ // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew
|
|
|
|
+ // "line up". When we access them, e.g. if gK has 128 rows and gKnew has 64
|
|
|
|
+ // rows, we access gK[:128] and gKNew[128:128 + 64]. This maps to accessing
|
|
|
|
+ // the first 64 rows of knew_ptr.
|
|
|
|
+ Tensor gKnew = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.knew_ptr) +
|
|
|
|
+ row_offset_knew -
|
|
|
|
+ binfo.seqlen_k_cache * params.knew_row_stride),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(params.knew_row_stride, _1{}));
|
|
|
|
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
|
|
|
+ // printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n",
|
|
|
|
+ // params.knew_ptr, row_offset_knew, gKnew.data()); }
|
|
|
|
+ Tensor gVnew = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.vnew_ptr) +
|
|
|
|
+ row_offset_vnew -
|
|
|
|
+ binfo.seqlen_k_cache * params.vnew_row_stride),
|
|
|
|
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(params.vnew_row_stride, _1{}));
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
|
|
|
|
+ auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
|
|
|
|
+ Tensor tKgKnew_ =
|
|
|
|
+ gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
|
|
|
|
+ Tensor tVgVnew_ =
|
|
|
|
+ gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
|
|
|
|
+
|
|
|
|
+ auto tKgKnew =
|
|
|
|
+ make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout()));
|
|
|
|
+ auto tVgVnew =
|
|
|
|
+ make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout()));
|
|
|
|
+
|
|
|
|
+ const int n_block_copy_min =
|
|
|
|
+ std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
|
|
|
|
+ auto tKgK_data = tKgK.data();
|
|
|
|
+ auto tVgV_data = tVgV.data();
|
|
|
|
+ for (int n_block = n_block_max - 1; n_block >= n_block_copy_min;
|
|
|
|
+ n_block--) {
|
|
|
|
+ flash::copy_w_min_idx<Is_even_K>(
|
|
|
|
+ tVgVnew, tVgV, tKVcKV, tKVpKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN,
|
|
|
|
+ binfo.seqlen_k_cache - n_block * kBlockN);
|
|
|
|
+ tVgVnew.data() =
|
|
|
|
+ tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
|
|
|
|
+ if (params.rotary_dim == 0) {
|
|
|
|
+ flash::copy_w_min_idx<Is_even_K>(
|
|
|
|
+ tKgKnew, tKgK, tKVcKV, tKVpKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN,
|
|
|
|
+ binfo.seqlen_k_cache - n_block * kBlockN);
|
|
|
|
+ } else {
|
|
|
|
+ if (params.is_rotary_interleaved) {
|
|
|
|
+ // Don't clear OOB_K because we're writing to global memory
|
|
|
|
+ flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
|
|
|
|
+ tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN,
|
|
|
|
+ binfo.seqlen_k_cache - n_block * kBlockN, params.d,
|
|
|
|
+ params.rotary_dim);
|
|
|
|
+ tRgCos.data() =
|
|
|
|
+ tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
|
|
|
|
+ tRgSin.data() =
|
|
|
|
+ tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
|
|
|
|
+ } else {
|
|
|
|
+ // Don't clear OOB_K because we're writing to global memory
|
|
|
|
+ flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
|
|
|
|
+ tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN,
|
|
|
|
+ binfo.seqlen_k_cache - n_block * kBlockN, params.d,
|
|
|
|
+ params.rotary_dim);
|
|
|
|
+ tRgCosCont.data() =
|
|
|
|
+ tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
|
|
|
|
+ tRgSinCont.data() =
|
|
|
|
+ tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ tKgKnew.data() =
|
|
|
|
+ tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
|
|
|
|
+ if (block_table == nullptr) {
|
|
|
|
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
|
|
|
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
|
|
|
+ } else {
|
|
|
|
+ if (n_block > n_block_copy_min) {
|
|
|
|
+ tVgV.data() =
|
|
|
|
+ gV.data() +
|
|
|
|
+ flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
|
|
|
|
+ tidx, n_block, params.page_block_size, block_table,
|
|
|
|
+ params.v_batch_stride, params.v_row_stride);
|
|
|
|
+ tKgK.data() =
|
|
|
|
+ gK.data() +
|
|
|
|
+ flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
|
|
|
|
+ tidx, n_block, params.page_block_size, block_table,
|
|
|
|
+ params.k_batch_stride, params.k_row_stride);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // Need this before we can read in K again, so that we'll see the updated K
|
|
|
|
+ // values.
|
|
|
|
+ __syncthreads();
|
|
|
|
+ tKgK.data() = tKgK_data;
|
|
|
|
+ tVgV.data() = tVgV_data;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Read Q from gmem to smem, optionally apply rotary embedding.
|
|
|
|
+ if (!Append_KV || params.rotary_dim == 0) {
|
|
|
|
+ // We don't need to clear the sQ smem tiles since we'll only write out the
|
|
|
|
+ // valid outputs
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K>(
|
|
|
|
+ gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM);
|
|
|
|
+ } else {
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
|
|
|
|
+ auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
|
|
|
|
+ typename Kernel_traits::GmemTiledCopyRotcossinCont
|
|
|
|
+ gmem_tiled_copy_rotary_cont;
|
|
|
|
+ auto gmem_thr_copy_rotary_cont =
|
|
|
|
+ gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
|
|
|
|
+ const index_t row_offset_cossin =
|
|
|
|
+ (binfo.seqlen_k_cache +
|
|
|
|
+ (Is_causal || Is_local ? m_block * kBlockM : 0)) *
|
|
|
|
+ (params.rotary_dim / 2);
|
|
|
|
+ // If not causal, all the queries get the same the cos/sin, taken at
|
|
|
|
+ // location seqlen_k_cache. We do this by setting the row stride of gCos /
|
|
|
|
+ // gSin to 0.
|
|
|
|
+ Tensor gCos = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
|
|
|
|
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
|
|
|
|
+ Tensor gSin = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
|
|
|
|
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
|
|
|
|
+ Tensor gCosCont = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
|
|
|
|
+ Tensor gSinCont = make_tensor(
|
|
|
|
+ make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
|
|
|
|
+ row_offset_cossin),
|
|
|
|
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
|
|
|
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
|
|
|
|
+ Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
|
|
|
|
+ Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
|
|
|
|
+ Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
|
|
|
|
+ Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
|
|
|
|
+ if (params.is_rotary_interleaved) {
|
|
|
|
+ flash::copy_rotary_interleaved<Is_even_K>(
|
|
|
|
+ tQgQ, tQsQ, tRgCos, tRgSin, tQcQ,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d,
|
|
|
|
+ params.rotary_dim);
|
|
|
|
+ } else {
|
|
|
|
+ flash::copy_rotary_contiguous<Is_even_K>(
|
|
|
|
+ tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ,
|
|
|
|
+ binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d,
|
|
|
|
+ params.rotary_dim);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ int n_block = n_block_max - 1;
|
|
|
|
+ // We don't need to clear the sK smem tiles since we'll mask out the scores
|
|
|
|
+ // anyway.
|
|
|
|
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV,
|
|
|
|
+ tKVpKV,
|
|
|
|
+ binfo.actual_seqlen_k - n_block * kBlockN);
|
|
|
|
+ cute::cp_async_fence();
|
|
|
|
+
|
|
|
|
+ // flash::cp_async_wait<0>();
|
|
|
|
+ // __syncthreads();
|
|
|
|
+ // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
|
|
|
|
+ // __syncthreads();
|
|
|
|
+
|
|
|
|
+ clear(acc_o);
|
|
|
|
+
|
|
|
|
+ flash::Softmax<2 * size<1>(acc_o)> softmax;
|
|
|
|
+
|
|
|
|
+ const float alibi_slope =
|
|
|
|
+ !Has_alibi ? 0.0f
|
|
|
|
+ : reinterpret_cast<float*>(params.alibi_slopes_ptr)
|
|
|
|
+ [bidb * params.alibi_slopes_batch_stride + bidh] /
|
|
|
|
+ params.scale_softmax;
|
|
|
|
+ flash::Mask<Is_causal, Is_local, Has_alibi> mask(
|
|
|
|
+ binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left,
|
|
|
|
+ params.window_size_right, alibi_slope);
|
|
|
|
+
|
|
|
|
+ // For performance reason, we separate out two kinds of iterations:
|
|
|
|
+ // those that need masking on S, and those that don't.
|
|
|
|
+ // We need masking on S for the very last block when K and V has length not
|
|
|
|
+ // multiple of kBlockN. We also need masking on S if it's causal, for the last
|
|
|
|
+ // ceil_div(kBlockM, kBlockN) blocks. We will have at least 1 "masking"
|
|
|
|
+ // iteration.
|
|
|
|
+
|
|
|
|
+ // If not even_N, then seqlen_k might end in the middle of a block. In that
|
|
|
|
+ // case we need to mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
|
|
|
+ constexpr int n_masking_steps =
|
|
|
|
+ (!Is_causal && !Is_local)
|
|
|
|
+ ? 1
|
|
|
|
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN)
|
|
|
|
+ : cute::ceil_div(kBlockM, kBlockN) + 1);
+#pragma unroll
+  for (int masking_step = 0; masking_step < n_masking_steps;
+       ++masking_step, --n_block) {
+    Tensor acc_s = partition_fragment_C(
+        tiled_mma,
+        Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+
+    // Advance gV
+    if (masking_step > 0) {
+      if (block_table == nullptr) {
+        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+      } else {
+        tVgV.data() =
+            gV.data() +
+            flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                tidx, n_block + 1, params.page_block_size, block_table,
+                params.v_batch_stride, params.v_row_stride);
+      }
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV,
+                                                  tVsV, tKVcKV, tKVpKV);
+    } else {
+      // Clear the smem tiles to account for predicated off loads
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+          gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV,
+          binfo.actual_seqlen_k - n_block * kBlockN);
+    }
+    cute::cp_async_fence();
+
+    flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
+                smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
+    // if (cute::thread0()) { print(acc_s); }
+    if constexpr (Is_softcap) {
+      apply_softcap(acc_s, params.softcap);
+    }
+
+    mask.template apply_mask<Is_causal, Is_even_MN>(
+        acc_s, n_block * kBlockN,
+        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
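+    // The third argument is this thread's first row of the score tile: with
+    // the SM80 16x8x16 MMA, each warp presumably owns a 16-row slab (hence the
+    // kNWarps * 16 row stride) and each group of 4 lanes within a warp shares
+    // one row, which is what (tidx / 32) * 16 + (tidx % 32) / 4 computes.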
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+    // __syncthreads();
+
+    if (n_block > n_block_min) {
+      // Advance gK
+      if (block_table == nullptr) {
+        tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      } else {
+        tKgK.data() = gK.data() +
+                      flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                          tidx, n_block, params.page_block_size, block_table,
+                          params.k_batch_stride, params.k_row_stride);
+      }
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK,
+                                                  tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the
+      // synchronization isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // We have key_padding_mask so we'll need to Check_inf
+    masking_step == 0
+        ? softmax.template softmax_rescale_o</*Is_first=*/true,
+                                             /*Check_inf=*/Is_causal ||
+                                                 Is_local || !Is_even_MN>(
+              acc_s, acc_o, params.scale_softmax_log2)
+        : softmax.template softmax_rescale_o</*Is_first=*/false,
+                                             /*Check_inf=*/Is_causal ||
+                                                 Is_local || !Is_even_MN>(
+              acc_s, acc_o, params.scale_softmax_log2);
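+    // Roughly, softmax_rescale_o performs the online-softmax update per row:
+    //   m_new  = max(m_old, rowmax(acc_s))
+    //   acc_o *= exp2f((m_old - m_new) * scale_softmax_log2)
+    //   acc_s  = exp2f((acc_s - m_new) * scale_softmax_log2)
+    //   sum   += rowsum(acc_s)
+    // so contributions from earlier K/V blocks stay correctly weighted as new
+    // row maxima appear.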
+    // if (cute::thread0()) { print(scores_max); print(scores_sum);
+    // print(scores); }
+
+    // Convert acc_s from fp32 to fp16/bf16
+    Tensor rP = flash::convert_type<Element>(acc_s);
+    // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(
+        rP.data(),
+        flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+    flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
+                   smem_thr_copy_V);
+
+    // This check is at the end of the loop since we always have at least 1
+    // iteration
+    if (n_masking_steps > 1 && n_block <= n_block_min) {
+      --n_block;
+      break;
+    }
+  }
+
+  // These are the iterations where we don't need masking on S
+  for (; n_block >= n_block_min; --n_block) {
+    Tensor acc_s = partition_fragment_C(
+        tiled_mma,
+        Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // Advance gV
+    if (block_table == nullptr) {
+      tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+    } else {
+      tVgV.data() = gV.data() +
+                    flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                        tidx, n_block + 1, params.page_block_size, block_table,
+                        params.v_batch_stride, params.v_row_stride);
+    }
+    flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV,
+                                                tKVcKV, tKVpKV);
+    cute::cp_async_fence();
+
+    flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
+                smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
+    if constexpr (Is_softcap) {
+      apply_softcap(acc_s, params.softcap);
+    }
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      // Advance gK
+      if (block_table == nullptr) {
+        tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      } else {
+        tKgK.data() = gK.data() +
+                      flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                          tidx, n_block, params.page_block_size, block_table,
+                          params.k_batch_stride, params.k_row_stride);
+      }
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK,
+                                                  tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the
+      // synchronization isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    mask.template apply_mask</*Causal_mask=*/false>(
+        acc_s, n_block * kBlockN,
+        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
+    softmax
+        .template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(
+            acc_s, acc_o, params.scale_softmax_log2);
+
+    Tensor rP = flash::convert_type<Element>(acc_s);
+    // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(
+        rP.data(),
+        flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+    flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
+                   smem_thr_copy_V);
+  }
+
+  // Epilogue
+
+  Tensor lse =
+      softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(
+          acc_o, params.scale_softmax);
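+  // normalize_softmax_lse is expected to divide each row of acc_o by its
+  // accumulated sum and to return lse = rowmax * scale_softmax + log(rowsum),
+  // i.e. the per-row log-sum-exp that the split-KV combine step consumes.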
+  // if (cute::thread0()) { print(lse); }
+
+  Tensor sOaccum =
+      make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)),
+                  typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
+  // Partition sO to match the accumulator partitioning
+  using SmemTiledCopyO =
+      std::conditional_t<!Split, typename Kernel_traits::SmemCopyAtomO,
+                         typename Kernel_traits::SmemCopyAtomOaccum>;
+  auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
+  auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor rO = flash::convert_type<ElementO>(acc_o);
+  Tensor taccOrOaccum =
+      smem_thr_copy_Oaccum.retile_S(rO);  // ((Atom,AtomNum), MMA_M, MMA_N)
+  Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(
+      sOaccum);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+  // sOaccum is larger than sQ, so we need to syncthreads here
+  // TODO: allocate enough smem for sOaccum
+  if constexpr (Split) {
+    __syncthreads();
+  }
+
+  cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
+
+  const index_t row_offset_o =
+      binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) +
+      m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+  const index_t row_offset_oaccum =
+      (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
+       m_block * kBlockM) *
+      params.d_rounded;
+  const index_t row_offset_lseaccum =
+      (Split || !params.unpadded_lse
+           ? ((n_split_idx * params.b + bidb) * params.h + bidh) *
+                 params.seqlen_q
+           : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) +
+      m_block * kBlockM;
+
+  Tensor gOaccum =
+      make_tensor(make_gmem_ptr(reinterpret_cast<ElementO*>(
+                                    Split ? params.oaccum_ptr : params.o_ptr) +
+                                (Split ? row_offset_oaccum : row_offset_o)),
+                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                  make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+  Tensor gLSEaccum = make_tensor(
+      make_gmem_ptr(
+          reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr
+                                                : params.softmax_lse_ptr) +
+          row_offset_lseaccum),
+      Shape<Int<kBlockM>>{}, Stride<_1>{});
+  // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n",
+  // row_offset_o, bidh, gOaccum.data()); }
+
+  GmemTiledCopyO gmem_tiled_copy_Oaccum;
+  auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(
+      sOaccum);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+  Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+
+  __syncthreads();
+
+  Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+  cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
+
+  Tensor caccO = make_identity_tensor(
+      Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor taccOcO = thr_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
+  static_assert(decltype(size<0>(taccOcO))::value == 4);
+  // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+  Tensor taccOcO_row =
+      logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+  CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
+  if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+    for (int mi = 0; mi < size(lse); ++mi) {
+      const int row = get<0>(taccOcO_row(mi));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+        gLSEaccum(row) = lse(mi);
+      }
+    }
+  }
+
+  // Construct identity layout for sO
+  Tensor cO = make_identity_tensor(make_shape(
+      size<0>(sOaccum), size<1>(sOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  // Repeat the partitioning with identity layouts
+  Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(
+      cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpO); ++k) {
+      tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+    }
+  }
+  // Clear_OOB_K must be false since we don't want to write zeros to gmem
+  flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
+              /*Clear_OOB_K=*/false>(gmem_tiled_copy_Oaccum, tOrOaccum,
+                                     tOgOaccum, tOcO, tOpO,
+                                     binfo.actual_seqlen_q - m_block * kBlockM);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
+          bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
+          bool Is_softcap, bool Return_softmax, typename Params>
+inline __device__ void compute_attn(const Params& params) {
+  const int m_block = blockIdx.x;
+  // The block index for the batch.
+  const int bidb = blockIdx.y;
+  // The block index for the head.
+  const int bidh = blockIdx.z;
+
+  // We want the fwd and bwd to generate the same dropout pattern (RNG),
+  // without restricting them to have the same number of threads or having to
+  // traverse the attention matrix in the same order. In the Philox RNG, we use
+  // the offset to store the batch, head, and the lane id (within a warp). We
+  // use the subsequence to store the location of the 16 x 32 blocks within the
+  // attention matrix. This way, as long as we have the batch, head, and the
+  // location of the 16 x 32 block within the attention matrix, we can generate
+  // the exact same dropout pattern.
+
+  flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local,
+                                Has_alibi, Is_even_MN, Is_even_K, Is_softcap,
+                                Return_softmax>(params, bidb, bidh, m_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
+          bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split,
+          bool Append_KV, typename Params>
+inline __device__ void compute_attn_splitkv(const Params& params) {
+  const int m_block = blockIdx.x;
+  // The block index for the batch.
+  const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
+  // The block index for the head.
+  const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+  const int n_split_idx = Split ? blockIdx.y : 0;
+  const int num_n_splits = Split ? gridDim.y : 1;
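+  // Grid mapping note: when Split is true the launch presumably uses
+  // grid = (num_m_blocks, num_splits, batch * num_heads), so blockIdx.z packs
+  // (batch, head) and blockIdx.y is the KV-split index; without Split the grid
+  // is (num_m_blocks, batch, num_heads) and there is a single split.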
+  flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local,
+                                        Has_alibi, Is_even_MN, Is_even_K,
+                                        Is_softcap, Split, Append_KV>(
+      params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, int kBlockM, int Log_max_splits,
+          bool Is_even_K, typename Params>
+inline __device__ void combine_attn_seqk_parallel(const Params& params) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+  constexpr int kMaxSplits = 1 << Log_max_splits;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNThreads = Kernel_traits::kNThreads;
+
+  static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
+  static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32,
+                "kBlockM must be 4, 8, 16 or 32");
+  static_assert(kNThreads == 128, "We assume that each block has 128 threads");
+
+  // Shared memory.
+  // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
+  __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
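+  // The +1 pad makes the row stride (kBlockM + 1 floats) odd and therefore
+  // coprime with the 32 shared-memory banks; e.g. with kBlockM = 16 an
+  // unpadded stride of 16 floats would land the column-wise reads of the
+  // transpose step on only two banks, while a stride of 17 spreads 32
+  // consecutive rows across all 32 banks.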
+
+  // The thread and block index.
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  const index_t lse_size = params.b * params.h * params.seqlen_q;
+
+  const index_t row_offset_lse = bidx * kBlockM;
+  Tensor gLSEaccum = make_tensor(
+      make_gmem_ptr(
+          reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) +
+          row_offset_lse),
+      Shape<Int<kMaxSplits>, Int<kBlockM>>{}, make_stride(lse_size, _1{}));
+
+  // LSE format is different depending on params.unpadded_lse and
+  // params.seqlenq_ngroups_swapped, see comment in get_lse_tile. This tensor's
+  // layout maps row_offset_lse to {bidb, bidh, q_offset}.
+  Tensor gLSE = make_tensor(
+      make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) +
+                    row_offset_lse),
+      Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+  // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb,
+  // q_offset}.
+  Layout flat_layout = make_layout(lse_size);
+  Layout orig_layout =
+      make_layout(make_shape(params.seqlen_q, params.h, params.b));
+  auto transposed_stride =
+      params.seqlenq_ngroups_swapped
+          ? make_stride(params.b, params.seqlen_q * params.b, 1)
+          : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+  Layout remapped_layout = make_layout(
+      make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+  Layout final_layout = cute::composition(
+      remapped_layout, cute::composition(orig_layout, flat_layout));
+
+  Tensor gLSE_unpadded = make_tensor(
+      make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
+      final_layout);
+
+  constexpr int kNLsePerThread =
+      (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
+
+  // Read the LSE values from gmem and store them in shared memory, then
+  // transpose them.
+  constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
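+  // For instance, with kNThreads = 128, kBlockM = 16 and kMaxSplits = 128:
+  // kNLsePerThread = (128 * 16 + 127) / 128 = 16 and kRowsPerLoadLSE =
+  // 128 / 16 = 8, i.e. each pass of the loop below loads 8 of the kMaxSplits
+  // rows for all 16 LSE columns.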
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
+    const int col = tidx % kBlockM;
+    ElementAccum lse =
+        (row < params.num_splits && col < lse_size - bidx * kBlockM)
+            ? gLSEaccum(row, col)
+            : -INFINITY;
+    if (row < kMaxSplits) {
+      sLSE[row][col] = lse;
+    }
+    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse
+    // = %f\n", tidx, row, col, lse); }
+  }
+  // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse =
+  // %f\n", tidx, row_offset_lse, lse_accum(0)); }
+  __syncthreads();
+  Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
+  constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
+  // To make sure that kMaxSplits is within 1 warp: we decide how many elements
+  // within kMaxSplits each thread should hold. If kMaxSplits = 16, then each
+  // thread holds 2 elements (128 threads, kBlockM rows, so each time we load we
+  // can load 128 / kBlockM rows).
+  // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
+  // static_assert(kThreadsPerSplit <= 32);
+  static_assert(kRowsPerLoadTranspose <= 32);
+  static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+    const int col = tidx / kRowsPerLoadTranspose;
+    lse_accum(l) =
+        (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
+    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse
+    // = %f\n", tidx, row, col, lse_accum(l)); }
+  }
+
+  // Compute the logsumexp of the LSE along the split dimension.
+  ElementAccum lse_max = lse_accum(0);
+#pragma unroll
+  for (int l = 1; l < kNLsePerThread; ++l) {
+    lse_max = max(lse_max, lse_accum(l));
+  }
+  MaxOp<float> max_op;
+  lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
+  lse_max =
+      lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
+  float lse_sum = expf(lse_accum(0) - lse_max);
+#pragma unroll
+  for (int l = 1; l < kNLsePerThread; ++l) {
+    lse_sum += expf(lse_accum(l) - lse_max);
+  }
+  SumOp<float> sum_op;
+  lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
+  // For the case where all local lse == -INFINITY, we want to set lse_logsum to
+  // INFINITY. Otherwise lse_logsum is log(0.0) = -INFINITY and we get NaN when
+  // we do lse_accum(l) - lse_logsum.
+  ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum)
+                                ? INFINITY
+                                : logf(lse_sum) + lse_max;
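+  // This is the standard stable log-sum-exp combine across splits:
+  //   lse_total = lse_max + log(sum_s exp(lse_s - lse_max))
+  // and further down each split's partial output is weighted by
+  //   exp(lse_s - lse_total)
+  // before being summed, so the result matches a single un-split softmax.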
+  // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f,
+  // lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
+  if (tidx % kRowsPerLoadTranspose == 0 &&
+      tidx / kRowsPerLoadTranspose < kBlockM) {
+    if (params.unpadded_lse) {
+      const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+      if (lse_offset < lse_size) {
+        gLSE_unpadded(lse_offset) = lse_logsum;
+      }
+    } else {
+      gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+    }
+  }
+// Store the scales exp(lse - lse_logsum) in shared memory.
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+    const int col = tidx / kRowsPerLoadTranspose;
+    if (row < params.num_splits && col < kBlockM) {
+      sLSE[row][col] = expf(lse_accum(l) - lse_logsum);
+    }
+  }
+  __syncthreads();
+
+  const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
+  Tensor gOaccum = make_tensor(
+      make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.oaccum_ptr) +
+                    row_offset_oaccum),
+      Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+  constexpr int kBlockN = kNThreads / kBlockM;
+  using GmemLayoutAtomOaccum =
+      Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+  using GmemTiledCopyOaccum = decltype(make_tiled_copy(
+      Copy_Atom<DefaultCopy, ElementAccum>{}, GmemLayoutAtomOaccum{},
+      Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+  GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+  auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
+  Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
+  Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+  clear(tOrO);
+
+  // Predicates
+  Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
+  // Repeat the partitioning with identity layouts
+  Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
+  Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpOaccum); ++k) {
+      tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d;
+    }
+  }
+  // Load Oaccum in, then scale and accumulate into O.
+  for (int split = 0; split < params.num_splits; ++split) {
+    flash::copy</*Is_even_MN=*/false, Is_even_K>(
+        gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum,
+        params.b * params.h * params.seqlen_q - bidx * kBlockM);
+#pragma unroll
+    for (int m = 0; m < size<1>(tOrOaccum); ++m) {
+      int row = get<0>(tOcOaccum(0, m, 0));
+      ElementAccum lse_scale = sLSE[split][row];
+#pragma unroll
+      for (int k = 0; k < size<2>(tOrOaccum); ++k) {
+#pragma unroll
+        for (int i = 0; i < size<0>(tOrOaccum); ++i) {
+          tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
+        }
+      }
+      // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0],
+      // sLSE[split][1]); print(tOrOaccum); }
+    }
+    tOgOaccum.data() = tOgOaccum.data() +
+                       params.b * params.h * params.seqlen_q * params.d_rounded;
+  }
+  // if (cute::thread0()) { print_tensor(tOrO); }
+
+  Tensor rO = flash::convert_type<Element>(tOrO);
+// Write to gO
+#pragma unroll
+  for (int m = 0; m < size<1>(rO); ++m) {
+    const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
+    if (idx < params.b * params.h * params.seqlen_q) {
+      const int batch_idx = idx / (params.h * params.seqlen_q);
+      const int head_idx =
+          (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
+      // The index to the rows of Q
+      const int row = idx - batch_idx * (params.h * params.seqlen_q) -
+                      head_idx * params.seqlen_q;
+      auto o_ptr = reinterpret_cast<Element*>(params.o_ptr) +
+                   batch_idx * params.o_batch_stride +
+                   head_idx * params.o_head_stride + row * params.o_row_stride;
+#pragma unroll
+      for (int k = 0; k < size<2>(rO); ++k) {
+        if (Is_even_K || tOpOaccum(k)) {
+          const int col = get<1>(tOcOaccum(0, m, k));
+          Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
+                                  Shape<Int<decltype(size<0>(rO))::value>>{},
+                                  Stride<_1>{});
+          // TODO: Should check whether this uses vectorized stores, but it
+          // seems pretty fast.
+          copy(rO(_, m, k), gO);
+          // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d,
+          // batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx,
+          // batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
+          // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] =
+          // recast<uint64_t>(rO)(0, m, k);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace flash