utils.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <assert.h>
  6. #include <stdint.h>
  7. #include <stdlib.h>
  8. #include <cuda_fp16.h>
  9. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  10. #include <cuda_bf16.h>
  11. #endif
  12. #include <cute/tensor.hpp>
  13. #include <cute/atom/copy_atom.hpp>
  14. #include <cutlass/array.h>
  15. #include <cutlass/cutlass.h>
  16. #include <cutlass/numeric_conversion.h>
  17. #include <cutlass/numeric_types.h>
  // CHECK_CUDA(call): evaluates `call` once; if the resulting cudaError_t is not cudaSuccess, prints the
  // file, line and cudaGetErrorString() text to stderr and terminates the process with exit(1).
  // NOTE(review): fprintf/stderr need <stdio.h>, which this header does not include directly; it
  // currently relies on a transitive include.
  #define CHECK_CUDA(call)                                                                          \
    do {                                                                                            \
      cudaError_t status_ = call;                                                                   \
      if (cudaSuccess != status_) {                                                                 \
        fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
        exit(1);                                                                                    \
      }                                                                                             \
    } while (0)
  // Checks (and clears) the last recorded runtime error via cudaGetLastError(); intended right after a
  // kernel launch, which does not return an error code itself.
  #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
  27. namespace flash {
  28. using namespace cute;
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Binary "max" functor with the (x, y) call shape expected by Allreduce::run.
  // Generic version: returns x when x > y, otherwise y (so y wins on ties and on unordered/NaN compares).
  template <typename T>
  struct MaxOp {
    __device__ __forceinline__ T operator()(T const &x, T const &y) {
      if (x > y) { return x; }
      return y;
    }
  };
  // float specialization: uses the device max() overload instead of compare/select
  // (the original author notes this is slightly faster).
  template <>
  struct MaxOp<float> {
    __device__ __forceinline__ float operator()(float const &x, float const &y) {
      return max(x, y);
    }
  };
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Binary addition functor; same (x, y) call shape as MaxOp so it can be passed as the Operator of
  // Allreduce::run.
  template <typename T>
  struct SumOp {
    __device__ __forceinline__ T operator()(T const &lhs, T const &rhs) {
      return lhs + rhs;
    }
  };
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Butterfly all-reduce of `x` with `op` across aligned groups of THREADS lanes within a warp, using
  // __shfl_xor_sync. On return every lane of a group holds the reduction over that group's inputs.
  // The shuffle mask is the full warp (uint32_t(-1)), so all 32 lanes must make the call.
  template <int THREADS>
  struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
      // XOR distances THREADS/2, THREADS/4, ..., 1 — fully unrolled since THREADS is a compile-time constant.
      #pragma unroll
      for (int lane_mask = THREADS / 2; lane_mask >= 1; lane_mask /= 2) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, lane_mask));
      }
      return x;
    }
  };
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Two-lane case: a single exchange with the neighbouring lane (XOR distance 1).
  template <>
  struct Allreduce<2> {
    template <typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
      T const partner = __shfl_xor_sync(uint32_t(-1), x, 1);
      return op(x, partner);
    }
  };
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Regroups an MMA accumulator layout into a 2-mode (row, col) layout.
  // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
  // For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
  // The SM90 form is selected when mode 0 has rank 3; anything else takes the SM80 path.
  template <typename Layout>
  __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    constexpr bool kIsSm90Acc = decltype(rank<0>(acc_layout))::value == 3;
    if constexpr (kIsSm90Acc) {
      static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
      static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
      static_assert(decltype(rank(acc_layout))::value == 3);
      auto const row_layout = make_layout(get<0, 1>(acc_layout), get<1>(acc_layout));
      auto const col_layout = make_layout(get<0, 0>(acc_layout), get<0, 2>(acc_layout), get<2>(acc_layout));
      return make_layout(row_layout, col_layout);
    } else {
      static_assert(decltype(size<0>(acc_layout))::value == 4);
      static_assert(decltype(rank(acc_layout))::value == 3);
      // Split the 4-wide MMA mode into (2, 2): ((2, 2), MMA_M, MMA_N)
      auto const split = logical_divide(acc_layout, Shape<_2>{});
      auto const row_layout = make_layout(get<0, 1>(split), get<1>(split));
      auto const col_layout = make_layout(get<0, 0>(split), get<2>(split));
      return make_layout(row_layout, col_layout);
    }
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // For SM90, convert acc_layout from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)),
  // i.e. the row/col regrouping for an accumulator whose M and N modes are swapped.
  template <typename Layout>
  __forceinline__ __device__ auto convert_layout_acc_transposed_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto const row_layout = make_layout(get<0, 0>(acc_layout), get<0, 2>(acc_layout), get<2>(acc_layout));
    auto const col_layout = make_layout(get<0, 1>(acc_layout), get<1>(acc_layout));
    return make_layout(row_layout, col_layout);
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Reshapes an MMA accumulator layout so its registers can be viewed as the A operand of a second MMA.
  // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  // if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
  // For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
  // MMA_traits is only consulted on the SM80 path, for the K extent of its Shape_MNK.
  template <typename MMA_traits, typename Layout>
  __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value != 3) {  // SM80
      static_assert(decltype(size<0>(acc_layout))::value == 4);
      static_assert(decltype(rank(acc_layout))::value == 3);
      constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
      static_assert(mma_shape_K == 8 || mma_shape_K == 16);
      if constexpr (mma_shape_K == 16) {
        // Split MMA_N in two: (4, MMA_M, (2, MMA_N / 2))
        auto const split = logical_divide(acc_layout, Shape<X, X, _2>{});
        auto const vals = make_layout(get<0>(split), get<2, 0>(split));
        return make_layout(vals, get<1>(split), get<2, 1>(split));
      } else {
        // m16n8k8: the accumulator layout is used as-is.
        return acc_layout;
      }
    } else {  // SM90
      static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
      static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
      static_assert(decltype(rank(acc_layout))::value == 3);
      static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
      // Split the third value sub-mode in two: (2, 2, (2, N / 16))
      auto const split = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});
      auto const vals = make_layout(get<0>(split), get<1>(split), get<2, 0>(split));
      auto const cols = make_layout(get<2, 1>(split), get<2>(acc_layout));
      return make_layout(vals, get<1>(acc_layout), cols);
    }
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Converts every element of `tensor` to To_type with cutlass::NumericArrayConverter and returns a tensor
  // with the same layout over the converted values.
  // Precondition (HACK): tensor.data() must be reinterpretable as one contiguous
  // cutlass::Array<From_type, numel>; elements are converted in storage order and re-read through
  // tensor.layout().
  // NOTE(review): the returned tensor is a non-owning view (make_rmem_ptr) onto the function-local `frag`.
  // It is only usable while this call is inlined into the caller (hence __forceinline__); the storage is
  // formally dead after return. Verify call sites before switching to an owning result, since some may
  // keep only `.data()` of the returned view.
  template <typename To_type, typename Engine, typename Layout>
  __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int kNumElems = decltype(size(tensor))::value;
    using SrcArray = cutlass::Array<From_type, kNumElems>;
    cutlass::NumericArrayConverter<To_type, From_type, kNumElems> converter;
    SrcArray const &src = *reinterpret_cast<SrcArray const *>(tensor.data());
    auto frag = converter(src);
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Issues a warpgroup MMA over every K block (mode 2) of tCrA/tCrB, accumulating into tCrC.
  //   zero_init - set tiled_mma.accumulate_ to ScaleOut::Zero before the first K block, so that block
  //               overwrites tCrC instead of accumulating into it.
  //   wg_wait   - if >= 0, calls warpgroup_wait<wg_wait>() after issuing; a negative value skips the wait.
  //   arrive    - call warpgroup_arrive() before issuing.
  //   commit    - call warpgroup_commit_batch() after issuing.
  // A is treated as register-sourced (Is_RS) when TiledMma::FrgTypeA is not derived from
  // GMMA::DescriptorIterator; in that case tCrA is also fenced before and after.
  // After at least one K block, tiled_mma.accumulate_ is left at ScaleOut::One.
  template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
            typename TiledMma>
  __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // warpgroup_fence_operand takes a non-const reference, so const is cast away on tCrA.
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) { warpgroup_arrive(); }
    if constexpr (zero_init) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; }
    // K mode is unrolled by hand so scale-D can be switched to One after the first block.
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    if constexpr (commit) { warpgroup_commit_batch(); }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Predicated copy of a rank-3 partitioned tensor (MMA, MMA_M, MMA_K) from S to D using tiled_copy.
  //   Is_even_MN   - skip the row check; otherwise row m is copied only when the first coordinate of
  //                  identity_MN(0, m, 0) is < max_MN.
  //   Is_even_K    - skip the K check; otherwise slice k is copied only when predicate_K(k) is true.
  //   Clear_OOB_K  - zero D(_, m, k) for K slices rejected by predicate_K (in in-bounds rows only).
  //   Clear_OOB_MN - zero the whole D(_, m, _) for rows rejected by the max_MN bound.
  // Slices that are neither copied nor cleared are left untouched in D.
  template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
            typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
            typename Engine2, typename Layout2, typename Engine3, typename Layout3>
  __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                       Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                       Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // Clearing out-of-bounds rows while keeping out-of-bounds K slices is not a supported combination.
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
      bool const row_in_bounds = Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN;
      if (!row_in_bounds) {
        if (Clear_OOB_MN) { cute::clear(D(_, m, _)); }
        continue;
      }
      #pragma unroll
      for (int k = 0; k < size<2>(S); ++k) {
        bool const k_in_bounds = Is_even_K || predicate_K(k);
        if (k_in_bounds) {
          cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
        } else if (Clear_OOB_K) {
          cute::clear(D(_, m, k));
        }
      }
    }
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  //
  //
  // Need this register byte permute/shuffle to match register layout of
  // (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
  //
  // Construction picks __byte_perm selectors from the lane's position within its group of 4 lanes
  // (lanes 0 and 3 of each quad get one set, lanes 1 and 2 the other). operator() then rewrites the
  // fragment in place, 8 elements at a time: two 32-bit words are byte-permuted, exchanged inside the
  // lane quad with __shfl_sync (width 4, source lanes from upper_map/lower_map), and permuted back.
  // NOTE(review): assumes 1-byte (FP8) elements so each 32-bit word packs 4 of them — inferred from the
  // name and the +4 element offset; confirm.
  struct ReorgCFp8toAFp8 {
    int selectorEx0;
    int selectorEx1;
    int selectorEx4;
    int selectorEx5;
    int upper_map[4] = {0, 3, 1, 2};
    int lower_map[4] = {1, 2, 0, 3};
    CUTLASS_DEVICE ReorgCFp8toAFp8() {
      int const quad_lane = cutlass::canonical_lane_idx() % 4;
      bool const is_outer_lane = (quad_lane == 0) || (quad_lane == 3);
      selectorEx0 = is_outer_lane ? 0x3210 : 0x7654;
      selectorEx1 = is_outer_lane ? 0x7654 : 0x3210;
      selectorEx4 = is_outer_lane ? 0x5410 : 0x1054;
      selectorEx5 = is_outer_lane ? 0x7632 : 0x3276;
    }
    template <typename Fragment> CUTLASS_DEVICE auto operator()(Fragment &accum) {
      using namespace cute;
      auto vec_shape = shape<0>(accum);  // number of vector elements per tile.
      auto tiles_m = shape<1>(accum);    // number of tiles along M.
      auto tiles_n = shape<2>(accum);    // number of tiles along N.
      auto elems = accum.data();
      // Each inner step consumes 8 consecutive elements (a 2x2 group of rows/cols).
      auto const steps_per_m = tiles_n * size<2>(vec_shape) / 2;
      // Source lanes (within the 4-lane shuffle segment) for the two exchanged words.
      int const src_upper = upper_map[threadIdx.x % 4];
      int const src_lower = lower_map[threadIdx.x % 4];
      int offset = 0;
      #pragma unroll
      for (int mi = 0; mi < tiles_m; ++mi) {
        #pragma unroll
        for (int step = 0; step < steps_per_m; ++step) {
          uint32_t *const words = reinterpret_cast<uint32_t *>(&elems[offset]);
          auto const word_a = *words;
          auto const word_b = *reinterpret_cast<uint32_t *>(&elems[offset + 4]);
          auto hi = __byte_perm(word_a, word_b, selectorEx0);
          auto lo = __byte_perm(word_a, word_b, selectorEx1);
          hi = __shfl_sync(uint32_t(-1), hi, src_upper, 4);
          lo = __shfl_sync(uint32_t(-1), lo, src_lower, 4);
          words[0] = __byte_perm(hi, lo, selectorEx4);
          words[1] = __byte_perm(hi, lo, selectorEx5);
          offset += 8;
        }
      }
    }
  };
  // Reshape utility for converting the layout from the accumulator of GEMM-I
  // to operand A of GEMM-II.
  // Returns the layout (mode 0 of a Q row, M mode of tC, mode 1 of a Q row), where the Q-row layout is
  // built with make_layout_like from tQ(_, 0, _).
  struct ReshapeTStoTP {
    template <class FragmentC, class FragmentQ>
    CUTLASS_DEVICE auto operator()(FragmentC &&tC, FragmentQ &&tQ) {
      auto const q_row_layout = make_layout_like(tQ(_, 0, _).layout());
      auto const c_m_layout = get<1>(tC.layout());
      return make_layout(get<0>(q_row_layout), c_m_layout, get<1>(q_row_layout));
    }
  };
  // Stores the output tile sO to O's (M, K) tile `m_block` for head `bidh` / batch `bidb` with a TMA store.
  // Only the lane chosen by elect_one_sync() in warp `write_warp_idx` issues the copy and
  // tma_store_arrive(). There is deliberately no wait here; completion must be awaited elsewhere
  // (presumably via tma_store_wait by the caller — confirm).
  // `O` and NumCopyThreads are unused; they keep the signature parallel to write_tiled.
  template <int NumCopyThreads, typename ElemO, typename TMACopyO, typename LayoutO,
            typename TileShapeO, typename SMemO, typename SeqLenTraits>
  __forceinline__ __device__ void write_tma(
      ElemO* O, const TMACopyO& tma_store_O,
      const LayoutO& layout_O, const TileShapeO& tile_shape_O,
      const SMemO& sO, int m_block, int bidh, int bidb,
      const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
    Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape());
    Tensor gO = seqlen_traits_o.get_local_tile_tensor(mO, tile_shape_O, bidh, bidb)(_, _, m_block);  // (M, K)
    auto thr_tma_O = tma_store_O.get_slice(_0{});
    Tensor tOsO = thr_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
    Tensor tOgO = thr_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();
    bool const is_issuer = warp_idx == write_warp_idx && lane_predicate;
    if (is_issuer) {
      cute::copy(tma_store_O, tOsO, tOgO);
      tma_store_arrive();
    }
  }
  // Writes the output tile sO to O's (M, K) tile `m_block` for head `bidh` / batch `bidb` using a
  // per-thread TiledCopy (the non-TMA path).
  // The copy partition is indexed by threadIdx.x - NumCopyThreads, so this presumably runs on threads
  // [NumCopyThreads, NumCopyThreads + size(tiled_copy_O)) — TODO confirm against the caller.
  // A full tile is copied unpredicated. Otherwise only rows whose in-tile M coordinate is below
  // seqlen_traits_o.actual_seq_len - m_block * M are written (nothing when that count is <= 0).
  template <int NumCopyThreads, typename ElemO, typename TiledCopyO, typename LayoutO,
            typename TileShapeO, typename SMemO, typename SeqLenTraits>
  __forceinline__ __device__ void write_tiled(
      ElemO* O, const TiledCopyO& tiled_copy_O,
      const LayoutO& layout_O, const TileShapeO& tile_shape_O,
      const SMemO& sO, int m_block, int bidh, int bidb,
      const SeqLenTraits& seqlen_traits_o) {
    Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
    Tensor gO = seqlen_traits_o.get_local_tile_tensor(mO, tile_shape_O, bidh, bidb)(_, _, m_block);  // (M, K)
    ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads);
    Tensor tOgO = thr_copy_O.partition_D(gO);  // (CPY, CPY_M, CPY_K)
    Tensor tOsO = thr_copy_O.partition_S(sO);  // (CPY, CPY_M, CPY_K)
    // Prepare for TiledCopy.
    // Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst.
    // After grouping, the first dim is number of elements to read together.
    Tensor tOsOFlatten = cute::flatten(tOsO);
    Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten);
    Tensor tOgOFlatten = cute::flatten(tOgO);
    Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten);
    // Identity tensor partitioned like gO: maps each copied element back to its (m, k) coordinate in the tile.
    Tensor gOCounting = cute::make_identity_tensor(gO.shape());
    Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting);
    Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting);
    Tensor tSgOCountingGrouped =
        cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten);
    // Number of valid rows in this tile. Plain ternary instead of std::min: <algorithm> is not included
    // and calling std::min from device code depends on --expt-relaxed-constexpr.
    const int kNumMsPerTile = get<0>(tile_shape_O);
    int const rows_left = seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile;
    int const cta_m = rows_left < kNumMsPerTile ? rows_left : kNumMsPerTile;
    if (cta_m == kNumMsPerTile) {
      // Qualified: an unqualified `copy` here first finds flash::copy and only reaches CuTe through ADL.
      cute::copy(tiled_copy_O, tOsOGroup, tOgOGroup);
    } else {
      auto predicate_fn = [&](auto coords) {
        auto s_coords = tSgOCountingGrouped(_0{}, coords);
        return elem_less(get<0>(s_coords), cta_m);
      };
      cute::copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
    }
  }
  // Stores the output tile sO to global memory. With IsTMACopy it uses write_tma (single elected thread,
  // no completion wait); otherwise it uses write_tiled (per-thread TiledCopy with row predication).
  // Because the branch is `if constexpr`, only the selected helper is instantiated, so the unused copy
  // object may be a placeholder type. write_warp_idx is only consumed by the TMA path.
  template <bool IsTMACopy, int NumCopyThreads, typename ElemO,
            typename TMACopyO, typename TiledCopyO, typename LayoutO,
            typename TileShapeO, typename SMemO, typename SeqLenTraits>
  __forceinline__ __device__ void write_O(
      ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O,
      const LayoutO& layout_O, const TileShapeO& tile_shape_O,
      const SMemO& sO, int m_block, int bidh, int bidb,
      const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
    if constexpr (!IsTMACopy) {
      write_tiled<NumCopyThreads>(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o);
    } else {
      write_tma<NumCopyThreads>(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o, write_warp_idx);
    }
  }
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  345. } // namespace flash