/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdio.h>   // for fprintf in CHECK_CUDA
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp>  // For cute::elect_one_sync()

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
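
// Illustrative host-side usage of the macros above (buffer, size, and kernel
// names are hypothetical):
//   CHECK_CUDA(cudaMalloc(&d_buf, bytes));
//   some_kernel<<<grid, block, smem_bytes, stream>>>(args);
//   CHECK_CUDA_KERNEL_LAUNCH();
// CHECK_CUDA aborts with the file/line and CUDA error string on failure;
// CHECK_CUDA_KERNEL_LAUNCH checks the error code of the most recent launch.
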
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
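
// Illustrative sketch (helper name is ours, not part of the original API): how
// MaxOp and Allreduce compose. Each lane passes in one float and every lane of
// the warp gets back the warp-wide maximum. Assumes all 32 lanes are active,
// since Allreduce shuffles with the full mask uint32_t(-1).
__forceinline__ __device__ float warp_allreduce_max_example(float x) {
    MaxOp<float> max_op;
    return Allreduce<32>::run(x, max_op);
}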
////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    }
};
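
// Illustrative sketch (tensor and helper names are ours): the typical use of
// convert_layout_acc_rowcol is to re-view an accumulator fragment as
// (row, col) so softmax-style reductions can loop over rows and columns.
// No data is moved; only the layout is reinterpreted.
template <typename Tensor0>
__forceinline__ __device__ auto acc_rowcol_view_example(Tensor0 &acc) {
    // e.g. acc from partition_fragment_C; the result is indexed as view(row, col).
    return make_tensor(acc.data(), convert_layout_acc_rowcol(acc.layout()));
}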
////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM90, convert acc_layout from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_transposed_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = acc_layout;
    return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16))
        return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};
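
// Illustrative usage (tensor and MMA names are hypothetical): in the SM90
// attention pipeline, the probabilities accumulated in acc_s are re-viewed as
// the A-operand registers of the second GEMM (P @ V), e.g.:
//   Tensor tOrP = make_tensor(acc_s.data(),
//                             convert_layout_acc_Aregs<TiledMmaPV>(acc_s.layout()));
// Again no data movement happens; the fragment is reinterpreted with the
// layout the MMA expects for its A operand.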
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
    auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _4>{});  // (2, 2, (2, N / 32))
    return make_layout(make_layout(Shape<_4, _2, _2>{}),
                       get<1>(acc_layout),
                       make_layout(get<2, 1>(l), get<2>(acc_layout)));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Byte permute for fp8 kernel
template <typename Fragment>
CUTLASS_DEVICE void permute_regs_A_to_C(Fragment &accum) {
    auto data = accum.data();
    #pragma unroll
    for (int n = 0; n < size(accum); n += 8) {
        uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
        auto upper = data_32bit[0];
        auto lower = data_32bit[1];
        data_32bit[0] = __byte_perm(upper, lower, 0x5410);
        data_32bit[1] = __byte_perm(upper, lower, 0x7632);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
    // Tensor out = make_tensor_like<To_type>(tensor);
    // cute::copy(make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()), out);
    // return out;
}
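
// Illustrative sketch (helper name is ours): convert_type is commonly used to
// narrow an fp32 register fragment to the element type of the next GEMM, e.g.
// cutlass::half_t, right before re-viewing it as an A operand. Assumes the
// fragment is contiguous in registers, as the HACK comment above requires.
template <typename Engine, typename Layout>
__forceinline__ __device__ auto convert_to_half_example(Tensor<Engine, Layout> const &acc) {
    return convert_type<cutlass::half_t>(acc);
}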
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
          typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
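
// Illustrative call pattern (tensor and MMA names are hypothetical): the first
// GEMM of an attention block typically starts a fresh accumulator and defers
// the wait so it can overlap with independent work:
//   flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK, tSrS);
//   ...                       // independent work here
//   warpgroup_wait<0>();      // drain the WGMMA batch before reading tSrS
// Passing wg_wait=-1 skips the internal warpgroup_wait, since the wait above
// only fires when wg_wait >= 0.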
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}
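
// Illustrative call (tensor names are hypothetical): copy a K tile from global
// memory with both row and column predication, zero-filling whatever falls
// outside the actual sequence length or head dimension:
//   flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false,
//               /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
//       gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
//       seqlen_k - n_block * kBlockN);
// identity_MN (tKcK) carries the logical (row, col) coordinates, predicate_K
// (tKpK) marks in-bounds columns, and max_MN is the number of valid rows.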
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int NumCopyThreads, typename ElemO, typename TMACopyO, typename LayoutO,
          typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_tma(
        ElemO* O, const TMACopyO& tma_store_O,
        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
        const SMemO& sO, int m_block, int bidh, int bidb,
        const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
    Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape());
    Tensor gO = seqlen_traits_o.get_local_tile_tensor(
        mO, tile_shape_O, bidh, bidb
    )(_, _, m_block);  // (M, K)
    auto block_tma_O = tma_store_O.get_slice(_0{});
    Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
    Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == write_warp_idx && lane_predicate) {
        cute::copy(tma_store_O, tOsO, tOgO);
        tma_store_arrive();
    }
    // Note: no wait here.
    // tma_store_wait<0>();
}
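
// Illustrative caller-side pattern (argument names are hypothetical): write_tma
// only issues tma_store_arrive(), so the caller is responsible for draining the
// store before sO is overwritten, e.g.:
//   flash::write_tma<NumCopyThreads>(O, tma_store_O, layout_O, tile_shape_O,
//                                    sO, m_block, bidh, bidb,
//                                    seqlen_traits_o, write_warp_idx);
//   tma_store_wait<0>();  // wait for the TMA store to finish reading sO
//   // ...only then is it safe to reuse the shared-memory buffer sO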
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int NumCopyThreads, typename ElemO, typename TiledCopyO, typename LayoutO,
          typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_tiled(
        ElemO* O, const TiledCopyO& tiled_copy_O,
        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
        const SMemO& sO, int m_block, int bidh, int bidb,
        const SeqLenTraits& seqlen_traits_o) {
    Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
    Tensor gO = seqlen_traits_o.get_local_tile_tensor(
        mO, tile_shape_O, bidh, bidb
    )(_, _, m_block);  // (M, K)

    ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads);
    Tensor tOgO = thr_copy_O.partition_D(gO);  // (CPY, CPY_M, CPY_K)
    Tensor tOsO = thr_copy_O.partition_S(sO);  // (CPY, CPY_M, CPY_K)

    // Prepare for TiledCopy.
    // Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst.
    // After grouping, the first dim is the number of elements to read together.
    Tensor tOsOFlatten = cute::flatten(tOsO);
    Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten);
    Tensor tOgOFlatten = cute::flatten(tOgO);
    Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten);

    // Map thread coordinates to global indices.
    Tensor gOCounting = cute::make_identity_tensor(gO.shape());
    Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting);
    Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting);
    Tensor tSgOCountingGrouped =
        cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten);

    // Write out to GMEM, predicating on the number of valid rows in this tile.
    const int kNumMsPerTile = get<0>(tile_shape_O);
    int cta_m = std::min(
        seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile, kNumMsPerTile
    );
    if (cta_m == kNumMsPerTile) {
        copy(tiled_copy_O, tOsOGroup, tOgOGroup);
    } else {
        auto predicate_fn = [&](auto coords) {
            auto s_coords = tSgOCountingGrouped(_0{}, coords);
            return elem_less(get<0>(s_coords), cta_m);
        };
        copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool IsTMACopy, int NumCopyThreads, typename ElemO,
          typename TMACopyO, typename TiledCopyO, typename LayoutO,
          typename TileShapeO, typename SMemO, typename SeqLenTraits>
__forceinline__ __device__ void write_O(
        ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O,
        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
        const SMemO& sO, int m_block, int bidh, int bidb,
        const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
    if constexpr (IsTMACopy) {
        write_tma<NumCopyThreads>(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o, write_warp_idx);
    } else {
        write_tiled<NumCopyThreads>(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash