copy_paged_sm90_tma_cutlass36.hpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. #pragma once
  2. #include <cute/arch/copy_sm90_tma.hpp>
  3. #include <cute/atom/copy_traits_sm90_tma.hpp>
  4. #include <cutlass/version.h>
  5. static_assert(CUTLASS_VERSION >= 360, "CUTLASS 3.6.x is required for this file due to incompatible API changes in Cutlass. Cutlass < 3.6 does not have the cache_hint argument to SM90_TMA_LOAD ops.");
  6. struct PagedCopyArgs {
  7. CUTE_HOST_DEVICE
  8. PagedCopyArgs() : block_table_batch_stride{0}, page_block_size(0), block_table(nullptr) {
  9. };
  10. CUTE_HOST_DEVICE
  11. PagedCopyArgs(int64_t const block_table_batch_stride_, int const page_block_size_, const int32_t *const block_table_) : block_table_batch_stride{block_table_batch_stride_}, page_block_size(page_block_size_), block_table(block_table_) {
  12. };
  13. const int64_t block_table_batch_stride; // The stride between block tables for different batches
  14. const int page_block_size; // The size of a page block in number of elements
  15. const int32_t *const block_table; // The block table, must be properly sized or a nullptr
  16. };
  17. namespace cute {
  18. struct SM90_TMA_LOAD_PAGED
  19. {
  20. using COPY_OP = SM90_TMA_LOAD; // The underlying copy operation that we delegate work to
  21. CUTE_HOST_DEVICE static void
  22. copy(void const* desc_ptr, uint64_t* mbar_ptr,
  23. void * smem_ptr,
  24. int32_t const& crd0)
  25. {
  26. CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 1D");
  27. }
  28. CUTE_HOST_DEVICE static void
  29. copy(void const* desc_ptr, uint64_t* mbar_ptr,
  30. PagedCopyArgs const* pca,
  31. void * smem_ptr,
  32. int32_t const& crd0, int32_t const& crd1)
  33. {
  34. CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 2D");
  35. }
  36. CUTE_HOST_DEVICE static void
  37. copy(void const* desc_ptr, uint64_t* mbar_ptr,
  38. PagedCopyArgs const* pca,
  39. void * smem_ptr,
  40. int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  41. {
  42. // WARNING: Do not place anything else here, or a performance regression will occur
  43. // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized"
  44. // asserts that pca==nullptr, but even an assert would kill performance
  45. return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crd0, crd1, crd2);
  46. }
  47. CUTE_HOST_DEVICE static void
  48. copy(void const* desc_ptr, uint64_t* mbar_ptr,
  49. PagedCopyArgs const* pca,
  50. void * smem_ptr,
  51. // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout()
  52. // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis )
  53. // and detail::make_tma_copy_desc to create a TMA descriptor.
  54. // The same reordering is aplied prior to calling via cute::tma_partition.
  55. // Final order determined experimentally.
  56. int32_t const& crdK, // embedding dim
  57. int32_t const& crdM, // sequence dim
  58. int32_t const& crdH, // head dim
  59. int32_t const& crdB) // batch dim
  60. {
  61. //auto log = pca.debug_log->nextline();
  62. //log.append_threadinfo();
  63. //log.snprintf("SM_90_TMA_LOAD_PAGED::copy(%d, %d, %d, %d) ", (int)crdM, (int)crdK, (int)crdH, (int)crdB);
  64. if (pca == nullptr) {
  65. return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, crdM, crdH, crdB);
  66. }
  67. auto const page_block_size = pca->page_block_size;
  68. int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry
  69. int32_t const seq_pos_offset = crdM - page_idx_offset * page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
  70. int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position
  71. //if (cute::thread0()) {
  72. // printf("SM90_TMA_LOAD_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr);
  73. //}
  74. return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, seq_pos_offset, crdH, page_idx);
  75. }
  76. CUTE_HOST_DEVICE static void
  77. copy(void const* desc_ptr, uint64_t* mbar_ptr,
  78. void * smem_ptr,
  79. int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  80. {
  81. CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 5D");
  82. }
  83. };
  84. struct SM90_TMA_LOAD_MULTICAST_PAGED
  85. {
  86. CUTE_HOST_DEVICE static void
  87. copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
  88. void * smem_ptr,
  89. int32_t const& crd0)
  90. {
  91. CUTE_INVALID_CONTROL_PATH("not implemented");
  92. }
  93. CUTE_HOST_DEVICE static void
  94. copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
  95. PagedCopyArgs const* pca,
  96. void * smem_ptr,
  97. int32_t const& crd0, int32_t const& crd1)
  98. {
  99. CUTE_INVALID_CONTROL_PATH("not implemented");
  100. }
  101. CUTE_HOST_DEVICE static void
  102. copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
  103. PagedCopyArgs const* pca,
  104. void * smem_ptr,
  105. int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  106. {
  107. // WARNING: Do not place anything else here, or a performance regression will occur
  108. // look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized"
  109. // asserts that pca==nullptr, but even an assert would kill performance
  110. return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crd0, crd1, crd2);
  111. }
  112. CUTE_HOST_DEVICE static void
  113. copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
  114. PagedCopyArgs const* pca,
  115. void * smem_ptr,
  116. // Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout()
  117. // via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis )
  118. // and detail::make_tma_copy_desc to create a TMA descriptor.
  119. // The same reordering is aplied prior to calling via cute::tma_partition.
  120. // Final order determined experimentally.
  121. int32_t const& crdK, // embedding dim
  122. int32_t const& crdM, // sequence dim
  123. int32_t const& crdH, // head dim
  124. int32_t const& crdB) // batch dim
  125. {
  126. if (pca == nullptr) {
  127. return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, crdM, crdH, crdB);
  128. }
  129. auto const page_block_size = pca->page_block_size;
  130. int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry
  131. int32_t const seq_pos_offset = crdM - page_idx_offset*page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
  132. int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position
  133. //if (cute::thread0()) {
  134. // printf("SM90_TMA_LOAD_MULTICAST_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr);
  135. //}
  136. return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_NORMAL), smem_ptr, crdK, seq_pos_offset, crdH, page_idx);
  137. }
  138. };
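  // We also need Copy_Traits specializations for the paged copy ops. They are written out in full,
  // modeled on CuTe's SM90_TMA_LOAD / SM90_TMA_LOAD_MULTICAST traits. The executable *_PAGED_OP
  // variants additionally carry a PagedCopyArgs pointer in their opargs_ tuple.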
  140. //////////////////////////////////////////////////////////////////////////////
  141. ///////////////////////////// TMA_LOAD ///////////////////////////////////////
  142. //////////////////////////////////////////////////////////////////////////////
  143. struct SM90_TMA_LOAD_PAGED_OP : SM90_TMA_LOAD_PAGED {};
  144. // The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar
  145. // Use .with(tma_mbar) to construct an executable version
  146. template <class NumBitsPerTMA, class AuxParams_>
  147. struct Copy_Traits<SM90_TMA_LOAD_PAGED, NumBitsPerTMA, AuxParams_>
  148. {
  149. using ThrID = Layout<_1>;
  150. // Map from (src-thr,src-val) to bit
  151. using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  152. // Map from (dst-thr,dst-val) to bit
  153. using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  154. // Reference map from (thr,val) to bit
  155. using RefLayout = SrcLayout;
  156. // SM90_TMA_LOAD arguments
  157. TmaDescriptor tma_desc_;
  158. using AuxParams = AuxParams_;
  159. AuxParams aux_params_;
  160. // Return TmaDescriptor/TensorMap
  161. CUTE_HOST_DEVICE constexpr
  162. TmaDescriptor const*
  163. get_tma_descriptor() const {
  164. return &tma_desc_;
  165. }
  166. // Construct an executable SM90_TMA_LOAD with tma_mbar
  167. CUTE_HOST_DEVICE constexpr
  168. Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
  169. with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  170. // We accept multicast_mask here to keep the API for both atoms consistent
  171. return {{}, {&tma_desc_, &tma_mbar, nullptr}};
  172. }
  173. // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
  174. CUTE_HOST_DEVICE constexpr
  175. Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
  176. with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  177. // We accept multicast_mask here to keep the API for both atoms consistent
  178. return {{}, {new_tma_desc, &tma_mbar, nullptr }};
  179. }
  180. CUTE_HOST_DEVICE constexpr
  181. Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
  182. with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  183. // We accept multicast_mask here to keep the API for both atoms consistent
  184. return {{}, {&tma_desc_, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
  185. }
  186. // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
  187. CUTE_HOST_DEVICE constexpr
  188. Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
  189. with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  190. // We accept multicast_mask here to keep the API for both atoms consistent
  191. return {{}, {new_tma_desc, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
  192. }
  193. // Generate the TMA coord tensor
  194. template <class GShape>
  195. CUTE_HOST_DEVICE constexpr
  196. auto
  197. get_tma_tensor(GShape const& g_shape) const {
  198. static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
  199. return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
  200. }
  201. // Don't try to execute a copy with SM90_TMA_LOAD before calling .with()
  202. template <class TS, class SLayout,
  203. class TD, class DLayout>
  204. CUTE_HOST_DEVICE friend constexpr void
  205. copy_unpack(Copy_Traits const& traits,
  206. Tensor<TS,SLayout> const& src,
  207. Tensor<TD,DLayout> & dst) = delete;
  208. };
  209. // The executable SM90_TMA_LOAD with tma_desc and tma_mbar
  210. template <class NumBitsPerTMA>
  211. struct Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
  212. : TMA_LOAD_Unpack<SM90_TMA_LOAD_PAGED_OP>
  213. {
  214. using ThrID = Layout<_1>;
  215. // Map from (src-thr,src-val) to bit
  216. using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  217. // Map from (dst-thr,dst-val) to bit
  218. using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  219. // Reference map from (thr,val) to bit
  220. using RefLayout = SrcLayout;
  221. // SM90_TMA_LOAD arguments
  222. tuple<
  223. TmaDescriptor const*,
  224. uint64_t*, // smem mbarrier
  225. PagedCopyArgs const*
  226. > const opargs_;
  227. };
  228. //////////////////////////////////////////////////////////////////////////////
  229. ///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
  230. //////////////////////////////////////////////////////////////////////////////
  231. struct SM90_TMA_LOAD_MULTICAST_PAGED_OP : SM90_TMA_LOAD_MULTICAST_PAGED {};
  232. // The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar
  233. // Use .with(tma_mbar, multicast_mask) to construct an executable version
  234. template <class NumBitsPerTMA, class AuxParams_>
  235. struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED, NumBitsPerTMA, AuxParams_>
  236. {
  237. using ThrID = Layout<_1>;
  238. // Map from (src-thr,src-val) to bit
  239. using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  240. // Map from (dst-thr,dst-val) to bit
  241. using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  242. // Reference map from (thr,val) to bit
  243. using RefLayout = SrcLayout;
  244. // SM90_TMA_LOAD_MULTICAST arguments
  245. TmaDescriptor tma_desc_;
  246. using AuxParams = AuxParams_;
  247. AuxParams aux_params_;
  248. // Return TmaDescriptor/TensorMap
  249. CUTE_HOST_DEVICE constexpr
  250. TmaDescriptor const*
  251. get_tma_descriptor() const {
  252. return &tma_desc_;
  253. }
  254. // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
  255. CUTE_HOST_DEVICE constexpr
  256. Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
  257. with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  258. return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, nullptr }};
  259. }
  260. // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
  261. CUTE_HOST_DEVICE constexpr
  262. Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
  263. with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  264. return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, nullptr }};
  265. }
  266. // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
  267. CUTE_HOST_DEVICE constexpr
  268. Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
  269. with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  270. return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
  271. }
  272. // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
  273. CUTE_HOST_DEVICE constexpr
  274. Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
  275. with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
  276. return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
  277. }
  278. // Generate the TMA coord tensor
  279. template <class GShape>
  280. CUTE_HOST_DEVICE constexpr
  281. auto
  282. get_tma_tensor(GShape const& g_shape) const {
  283. static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
  284. return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
  285. }
  286. // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with()
  287. template <class TS, class SLayout,
  288. class TD, class DLayout>
  289. CUTE_HOST_DEVICE friend constexpr void
  290. copy_unpack(Copy_Traits const& traits,
  291. Tensor<TS,SLayout> const& src,
  292. Tensor<TD,DLayout> & dst) = delete;
  293. };
  294. // The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask
  295. template <class NumBitsPerTMA>
  296. struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
  297. : TMA_LOAD_Unpack<SM90_TMA_LOAD_MULTICAST_PAGED_OP>
  298. {
  299. using ThrID = Layout<_1>;
  300. // Map from (src-thr,src-val) to bit
  301. using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  302. // Map from (dst-thr,dst-val) to bit
  303. using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  304. // Reference map from (thr,val) to bit
  305. using RefLayout = SrcLayout;
  306. // SM90_TMA_LOAD_MULTICAST arguments
  307. tuple<
  308. TmaDescriptor const*,
  309. uint64_t*, // smem mbarrier
  310. uint16_t, // multicast mask
  311. PagedCopyArgs const*
  312. > const opargs_;
  313. };
  314. template <class TmaInternalType = void,
  315. class CopyOp,
  316. class GEngine, class GLayout,
  317. class VShape,
  318. class SLayout,
  319. class CTA_Tiler,
  320. class Cluster_Size>
  321. CUTE_HOST_RTC
  322. auto
  323. make_virtualized_tma_copy(CopyOp const& copy_op,
  324. Tensor<GEngine,GLayout> const& gtensor,
  325. VShape const &virtual_shape,
  326. SLayout const slayout,
  327. CTA_Tiler const& cta_tiler,
  328. Cluster_Size const& cluster_size)
  329. {
  330. /**
  331. Variant of cute::make_tma_copy which allows to separate a virtual tensor coordinate space and
  332. a physical TMA tensor coordinate space. Used for Paged Attention with TMA.
  333. */
  334. auto cta_v_tile = make_identity_layout(virtual_shape).compose(cta_tiler);
  335. auto cta_t_tile = make_layout(cluster_size);
  336. //cute::print("\nVirtual Shape:"); cute::print(virtual_shape);
  337. //cute::print("\nPhysical Shape:"); cute::print(gtensor.layout().shape()); cute::print("\n");
  338. // Prefer TmaInternalType if specified. Fallback to GEngine::value_type
  339. using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
  340. return detail::make_tma_copy_tiled<TmaType>(copy_op,
  341. gtensor, slayout,
  342. cta_t_tile, cta_v_tile);
  343. }
  344. }