@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "namespace_config.h"
 #include "philox_unpack.cuh" // For at::cuda::philox::unpack
 
 #include <cute/tensor.hpp>
@@ -20,7 +21,7 @@
 #include "dropout.h"
 #include "rotary.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -66,7 +67,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
 
     auto seed_offset = at::cuda::philox::unpack(params.philox_args);
-    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+    FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
                            bidb, bidh, tidx, params.h);
 
     // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
@@ -115,7 +116,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
         }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
         );
         #pragma unroll
@@ -246,7 +247,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Prologue
 
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                        binfo.actual_seqlen_q - m_block * kBlockM);
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
 
@@ -255,7 +256,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // // if (cute::thread0()) { print(sQNoSwizzle); }
 
     if (Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
@@ -265,14 +266,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
     // __syncthreads();
 
     if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<1>();
+        FLASH_NAMESPACE::cp_async_wait<1>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
@@ -281,10 +282,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     clear(acc_o);
 
-    flash::Softmax<2 * size<1>(acc_o)> softmax;
+    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
     const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
+    FLASH_NAMESPACE::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
 
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -301,37 +302,37 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
         // Advance gV
         if (masking_step > 0) {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
 
-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        FLASH_NAMESPACE::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
         );
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -343,7 +344,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
         // Convert acc_s from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
@@ -361,9 +362,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
 
         // This check is at the end of the loop since we always have at least 1 iteration
@@ -377,23 +378,23 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
+        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        FLASH_NAMESPACE::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -405,7 +406,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
@@ -423,8 +424,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
 
     // Epilogue
@@ -432,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
 
     // Convert acc_o from fp32 to fp16/bf16
-    Tensor rO = flash::convert_type<Element>(acc_o);
+    Tensor rO = FLASH_NAMESPACE::convert_type<Element>(acc_o);
     Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
     auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
@@ -487,7 +488,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
@@ -563,7 +564,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
         }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
         );
         #pragma unroll
@@ -730,18 +731,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         auto tKgK_data = tKgK.data();
         auto tVgV_data = tVgV.data();
         for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
-            flash::copy_w_min_idx<Is_even_K>(
+            FLASH_NAMESPACE::copy_w_min_idx<Is_even_K>(
                 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
             );
             tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
             if (params.rotary_dim == 0) {
-                flash::copy_w_min_idx<Is_even_K>(
+                FLASH_NAMESPACE::copy_w_min_idx<Is_even_K>(
                     tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
                 );
             } else {
                 if (params.is_rotary_interleaved) {
                     // Don't clear OOB_K because we're writing to global memory
-                    flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+                    FLASH_NAMESPACE::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
                         tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                         binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                     );
@@ -749,7 +750,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                     tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
                 } else {
                     // Don't clear OOB_K because we're writing to global memory
-                    flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+                    FLASH_NAMESPACE::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
                         tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                         binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                     );
@@ -784,7 +785,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Read Q from gmem to smem, optionally apply rotary embedding.
     if (!Append_KV || params.rotary_dim == 0) {
         // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
@@ -807,12 +808,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
         Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
         if (params.is_rotary_interleaved) {
-            flash::copy_rotary_interleaved<Is_even_K>(
+            FLASH_NAMESPACE::copy_rotary_interleaved<Is_even_K>(
                 tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                 0, params.d, params.rotary_dim
             );
         } else {
-            flash::copy_rotary_contiguous<Is_even_K>(
+            FLASH_NAMESPACE::copy_rotary_contiguous<Is_even_K>(
                 tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                 0, params.d, params.rotary_dim
             );
@@ -821,21 +822,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
 
-    // flash::cp_async_wait<0>();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     // __syncthreads();
     // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
     // __syncthreads();
 
     clear(acc_o);
 
-    flash::Softmax<2 * size<1>(acc_o)> softmax;
+    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
+    FLASH_NAMESPACE::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
 
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -852,7 +853,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
         // Advance gV
@@ -866,22 +867,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
                 tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
 
-        flash::gemm(
+        FLASH_NAMESPACE::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
 
@@ -889,7 +890,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
         );
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
         // __syncthreads();
@@ -905,7 +906,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -918,12 +919,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
 
         // Convert acc_s from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 
         // This check is at the end of the loop since we always have at least 1 iteration
         if (n_masking_steps > 1 && n_block <= n_block_min) {
@@ -936,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
         // Advance gV
         if (block_table == nullptr) {
@@ -948,18 +949,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
             tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
         }
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
-        flash::gemm(
+        FLASH_NAMESPACE::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
             // Advance gK
@@ -972,7 +973,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -983,12 +984,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
 
     // Epilogue
@@ -1005,7 +1006,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     >;
     auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
     auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
-    Tensor rO = flash::convert_type<ElementO>(acc_o);
+    Tensor rO = FLASH_NAMESPACE::convert_type<ElementO>(acc_o);
     Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
@@ -1064,7 +1065,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
@@ -1087,7 +1088,7 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
+    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1101,7 +1102,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1242,7 +1243,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
     // Load Oaccum in then scale and accumulate to O
     for (int split = 0; split < params.num_splits; ++split) {
-        flash::copy</*Is_even_MN=*/false, Is_even_K>(
+        FLASH_NAMESPACE::copy</*Is_even_MN=*/false, Is_even_K>(
             gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
         );
         #pragma unroll
@@ -1262,7 +1263,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
     // if (cute::thread0()) { print_tensor(tOrO); }
 
-    Tensor rO = flash::convert_type<Element>(tOrO);
+    Tensor rO = FLASH_NAMESPACE::convert_type<Element>(tOrO);
     // Write to gO
     #pragma unroll
     for (int m = 0; m < size<1>(rO); ++m) {
@@ -1290,4 +1291,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
 }
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE
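Note on the new include: every `flash::` qualifier in this file is rewritten to the `FLASH_NAMESPACE` macro that the added `#include "namespace_config.h"` is expected to provide. That header is not shown in this diff, so the following is only a minimal sketch, assuming it merely has to define the macro with a backward-compatible default:

    // Hypothetical sketch of namespace_config.h (assumption -- not the file added by this patch).
    // Routing the namespace through a macro lets an embedding build (e.g. a framework
    // integration) remap the kernels into its own namespace to avoid symbol clashes,
    // while standalone builds keep the original `flash` namespace.
    #pragma once

    #ifndef FLASH_NAMESPACE
    #define FLASH_NAMESPACE flash
    #endif

With the default left in place, `FLASH_NAMESPACE::copy<...>` expands back to `flash::copy<...>`, so the rewrite is purely mechanical and behavior-preserving unless a build overrides the macro.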