123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
- #pragma once
- #include <cute/arch/copy_sm90_tma.hpp>
- #include <cute/atom/copy_traits_sm90_tma.hpp>
- #include <cutlass/version.h>
- static_assert(CUTLASS_VERSION >= 360, "CUTLASS 3.6.x is required for this file due to incompatible API changes in Cutlass. Cutlass < 3.6 does not have the cache_hint argument to SM90_TMA_LOAD ops.");
- struct PagedCopyArgs {
- CUTE_HOST_DEVICE
- PagedCopyArgs() : block_table_batch_stride{0}, page_block_size(0), block_table(nullptr) {
- };
- CUTE_HOST_DEVICE
- PagedCopyArgs(int64_t const block_table_batch_stride_, int const page_block_size_, const int32_t *const block_table_) : block_table_batch_stride{block_table_batch_stride_}, page_block_size(page_block_size_), block_table(block_table_) {
- };
- const int64_t block_table_batch_stride; // The stride between block tables for different batches
- const int page_block_size; // The size of a page block in number of elements
- const int32_t *const block_table; // The block table, must be properly sized or a nullptr
- };
- namespace cute {
- struct SM90_TMA_LOAD_PAGED
- {
- using COPY_OP = SM90_TMA_LOAD; // The underlying copy operation that we delegate work to
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr,
- void * smem_ptr,
- int32_t const& crd0)
- {
- CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 1D");
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr,
- PagedCopyArgs const* pca,
- void * smem_ptr,
- int32_t const& crd0, int32_t const& crd1)
- {
- CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 2D");
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr,
- PagedCopyArgs const* pca,
- void * smem_ptr,
- int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
- {
- // WARNING: Do not place anything else here, or a performance regression will occur
- // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized"
- // asserts that pca==nullptr, but even an assert would kill performance
- return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crd0, crd1, crd2);
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr,
- PagedCopyArgs const* pca,
- void * smem_ptr,
- // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout()
- // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis )
- // and detail::make_tma_copy_desc to create a TMA descriptor.
- // The same reordering is aplied prior to calling via cute::tma_partition.
- // Final order determined experimentally.
- int32_t const& crdK, // embedding dim
- int32_t const& crdM, // sequence dim
- int32_t const& crdH, // head dim
- int32_t const& crdB) // batch dim
- {
- //auto log = pca.debug_log->nextline();
- //log.append_threadinfo();
- //log.snprintf("SM_90_TMA_LOAD_PAGED::copy(%d, %d, %d, %d) ", (int)crdM, (int)crdK, (int)crdH, (int)crdB);
- if (pca == nullptr) {
- return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, crdM, crdH, crdB);
- }
- auto const page_block_size = pca->page_block_size;
- int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry
- int32_t const seq_pos_offset = crdM - page_idx_offset * page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
- int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position
- //if (cute::thread0()) {
- // printf("SM90_TMA_LOAD_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr);
- //}
-
- return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, seq_pos_offset, crdH, page_idx);
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr,
- void * smem_ptr,
- int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
- {
- CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 5D");
- }
- };
- struct SM90_TMA_LOAD_MULTICAST_PAGED
- {
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
- void * smem_ptr,
- int32_t const& crd0)
- {
- CUTE_INVALID_CONTROL_PATH("not implemented");
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
- PagedCopyArgs const* pca,
- void * smem_ptr,
- int32_t const& crd0, int32_t const& crd1)
- {
- CUTE_INVALID_CONTROL_PATH("not implemented");
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
- PagedCopyArgs const* pca,
- void * smem_ptr,
- int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
- {
- // WARNING: Do not place anything else here, or a performance regression will occur
- // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized"
- // asserts that pca==nullptr, but even an assert would kill performance
- return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crd0, crd1, crd2);
- }
- CUTE_HOST_DEVICE static void
- copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
- PagedCopyArgs const* pca,
- void * smem_ptr,
- // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout()
- // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis )
- // and detail::make_tma_copy_desc to create a TMA descriptor.
- // The same reordering is aplied prior to calling via cute::tma_partition.
- // Final order determined experimentally.
- int32_t const& crdK, // embedding dim
- int32_t const& crdM, // sequence dim
- int32_t const& crdH, // head dim
- int32_t const& crdB) // batch dim
- {
- if (pca == nullptr) {
- return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, crdM, crdH, crdB);
- }
- auto const page_block_size = pca->page_block_size;
- int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry
- int32_t const seq_pos_offset = crdM - page_idx_offset*page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
- int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position
- //if (cute::thread0()) {
- // printf("SM90_TMA_LOAD_MULTICAST_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr);
- //}
- return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, seq_pos_offset, crdH, page_idx);
-
- }
- };
- // We also need to specialize Copy_Traits for PAGED_COPY_OP, we can do this by inheriting from the traits of the underlying copy op
- //////////////////////////////////////////////////////////////////////////////
- ///////////////////////////// TMA_LOAD ///////////////////////////////////////
- //////////////////////////////////////////////////////////////////////////////
- struct SM90_TMA_LOAD_PAGED_OP : SM90_TMA_LOAD_PAGED {};
- // The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar
- // Use .with(tma_mbar) to construct an executable version
- template <class NumBitsPerTMA, class AuxParams_>
- struct Copy_Traits<SM90_TMA_LOAD_PAGED, NumBitsPerTMA, AuxParams_>
- {
- using ThrID = Layout<_1>;
- // Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Map from (dst-thr,dst-val) to bit
- using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Reference map from (thr,val) to bit
- using RefLayout = SrcLayout;
- // SM90_TMA_LOAD arguments
- TmaDescriptor tma_desc_;
- using AuxParams = AuxParams_;
- AuxParams aux_params_;
- // Return TmaDescriptor/TensorMap
- CUTE_HOST_DEVICE constexpr
- TmaDescriptor const*
- get_tma_descriptor() const {
- return &tma_desc_;
- }
- // Construct an executable SM90_TMA_LOAD with tma_mbar
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
- with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- // We accept multicast_mask here to keep the API for both atoms consistent
- return {{}, {&tma_desc_, &tma_mbar, nullptr}};
- }
- // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
- with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- // We accept multicast_mask here to keep the API for both atoms consistent
- return {{}, {new_tma_desc, &tma_mbar, nullptr }};
- }
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
- with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- // We accept multicast_mask here to keep the API for both atoms consistent
- return {{}, {&tma_desc_, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
- }
- // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
- with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- // We accept multicast_mask here to keep the API for both atoms consistent
- return {{}, {new_tma_desc, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
- }
- // Generate the TMA coord tensor
- template <class GShape>
- CUTE_HOST_DEVICE constexpr
- auto
- get_tma_tensor(GShape const& g_shape) const {
- static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
- return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
- }
- // Don't try to execute a copy with SM90_TMA_LOAD before calling .with()
- template <class TS, class SLayout,
- class TD, class DLayout>
- CUTE_HOST_DEVICE friend constexpr void
- copy_unpack(Copy_Traits const& traits,
- Tensor<TS,SLayout> const& src,
- Tensor<TD,DLayout> & dst) = delete;
- };
- // The executable SM90_TMA_LOAD with tma_desc and tma_mbar
- template <class NumBitsPerTMA>
- struct Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
- : TMA_LOAD_Unpack<SM90_TMA_LOAD_PAGED_OP>
- {
- using ThrID = Layout<_1>;
- // Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Map from (dst-thr,dst-val) to bit
- using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Reference map from (thr,val) to bit
- using RefLayout = SrcLayout;
- // SM90_TMA_LOAD arguments
- tuple<
- TmaDescriptor const*,
- uint64_t*, // smem mbarrier
- PagedCopyArgs const*
- > const opargs_;
- };
- //////////////////////////////////////////////////////////////////////////////
- ///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
- //////////////////////////////////////////////////////////////////////////////
- struct SM90_TMA_LOAD_MULTICAST_PAGED_OP : SM90_TMA_LOAD_MULTICAST_PAGED {};
- // The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar
- // Use .with(tma_mbar, multicast_mask) to construct an executable version
- template <class NumBitsPerTMA, class AuxParams_>
- struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED, NumBitsPerTMA, AuxParams_>
- {
- using ThrID = Layout<_1>;
- // Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Map from (dst-thr,dst-val) to bit
- using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Reference map from (thr,val) to bit
- using RefLayout = SrcLayout;
- // SM90_TMA_LOAD_MULTICAST arguments
- TmaDescriptor tma_desc_;
- using AuxParams = AuxParams_;
- AuxParams aux_params_;
- // Return TmaDescriptor/TensorMap
- CUTE_HOST_DEVICE constexpr
- TmaDescriptor const*
- get_tma_descriptor() const {
- return &tma_desc_;
- }
- // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
- with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, nullptr }};
- }
- // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
- with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, nullptr }};
- }
- // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
- with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
- }
- // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
- CUTE_HOST_DEVICE constexpr
- Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
- with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
- return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
- }
- // Generate the TMA coord tensor
- template <class GShape>
- CUTE_HOST_DEVICE constexpr
- auto
- get_tma_tensor(GShape const& g_shape) const {
- static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
- return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
- }
- // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with()
- template <class TS, class SLayout,
- class TD, class DLayout>
- CUTE_HOST_DEVICE friend constexpr void
- copy_unpack(Copy_Traits const& traits,
- Tensor<TS,SLayout> const& src,
- Tensor<TD,DLayout> & dst) = delete;
- };
- // The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask
- template <class NumBitsPerTMA>
- struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
- : TMA_LOAD_Unpack<SM90_TMA_LOAD_MULTICAST_PAGED_OP>
- {
- using ThrID = Layout<_1>;
- // Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Map from (dst-thr,dst-val) to bit
- using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
- // Reference map from (thr,val) to bit
- using RefLayout = SrcLayout;
- // SM90_TMA_LOAD_MULTICAST arguments
- tuple<
- TmaDescriptor const*,
- uint64_t*, // smem mbarrier
- uint16_t, // multicast mask
- PagedCopyArgs const*
- > const opargs_;
- };
- template <class TmaInternalType = void,
- class CopyOp,
- class GEngine, class GLayout,
- class VShape,
- class SLayout,
- class CTA_Tiler,
- class Cluster_Size>
- CUTE_HOST_RTC
- auto
- make_virtualized_tma_copy(CopyOp const& copy_op,
- Tensor<GEngine,GLayout> const& gtensor,
- VShape const &virtual_shape,
- SLayout const slayout,
- CTA_Tiler const& cta_tiler,
- Cluster_Size const& cluster_size)
- {
- /**
- Variant of cute::make_tma_copy which allows to separate a virtual tensor coordinate space and
- a physical TMA tensor coordinate space. Used for Paged Attention with TMA.
- */
- auto cta_v_tile = make_identity_layout(virtual_shape).compose(cta_tiler);
- auto cta_t_tile = make_layout(cluster_size);
- //cute::print("\nVirtual Shape:"); cute::print(virtual_shape);
- //cute::print("\nPhysical Shape:"); cute::print(gtensor.layout().shape()); cute::print("\n");
- // Prefer TmaInternalType if specified. Fallback to GEngine::value_type
- using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
- return detail::make_tma_copy_tiled<TmaType>(copy_op,
- gtensor, slayout,
- cta_t_tile, cta_v_tile);
- }
- }
|