utils.h
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdio.h>   // fprintf, stderr (used by CHECK_CUDA)
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())

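// Usage sketch for the error-checking macros above (host code only; the kernel
// name and launch configuration here are hypothetical):
//
//     CHECK_CUDA(cudaMalloc(&d_ptr, bytes));
//     my_kernel<<<grid, block, smem_bytes, stream>>>(d_ptr);
//     CHECK_CUDA_KERNEL_LAUNCH();                  // surfaces launch-time errors immediately
//     CHECK_CUDA(cudaStreamSynchronize(stream));
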
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

// A wrapper for the kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
    template <typename... Args>
    CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
        Kernel::operator()(std::forward<Args>(args)...);
#endif
    }
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
    template <typename... Args>
    CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890)
        Kernel::operator()(std::forward<Args>(args)...);
#endif
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

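// Butterfly all-reduce usage sketch: every lane of the warp ends up with the
// reduction of `val` over all 32 lanes (variable names are illustrative):
//
//     MaxOp<float> max_op;
//     float row_max = Allreduce<32>::run(val, max_op);   // warp-wide max
//     SumOp<float> sum_op;
//     float row_sum = Allreduce<32>::run(val, sum_op);   // warp-wide sum
//
// Smaller THREADS values (16/8/4) reduce only within groups of that many
// adjacent lanes, which matches the quad/oct layouts of MMA accumulators.
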
////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))
            return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        } else {
            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))
            // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
            // Will require register shuffling later to be correct.
            return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
                               get<1>(acc_layout),
                               coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
            // This combination is right but doesn't work with register shuffling.
            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
            //                    get<1>(acc_layout),
            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        }
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
CUTLASS_DEVICE auto convert_type_unsafe(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    static constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
    // Unsafe because we're returning a tensor with memory allocated on the stack. If the compiler does not
    // inline this function, then the memory might not be valid.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout, typename EngineOut>
CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
    // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
    using From_type = typename Engine::value_type;
    using To_type = typename EngineOut::value_type;
    static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));
    static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
    Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
    Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
    static_assert(size(frag) == size(out_frg));
    cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
    #pragma unroll
    for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }
}

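// Register-level downcast sketch (tensor names are illustrative): convert an fp32
// accumulator fragment into a bf16 fragment of the same layout before the store.
//
//     Tensor tOrO = make_tensor<cutlass::bfloat16_t>(tOrAcc.layout());  // same shape, 16-bit elements
//     flash::convert_type_out(tOrAcc, tOrO);                            // vectorized via cutlass::NumericArrayConverter
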
////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}

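// Typical software-pipelining pattern around this helper (sketch; `gmem_tiled_copy`
// and the tensors are illustrative):
//
//     cute::copy(gmem_tiled_copy, tQgQ, tQsQ);   // issue cp.async loads into smem
//     cute::cp_async_fence();                    // commit the group
//     ...                                        // issue the next stage's loads
//     flash::cp_async_wait<0>();                 // wait for all committed groups
//     __syncthreads();                           // make the smem writes visible to the CTA
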
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool A, class Mma, class Tensor0>
CUTLASS_DEVICE
auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
    if constexpr (A) {
        return mma.partition_fragment_A(tensor0);
    } else {
        return mma.partition_fragment_B(tensor0);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
          typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
    if constexpr (M_slice >= 0) {
        static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
        static_assert(M_slice < MMA_M);
        // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N)
        Tensor tCrC_slice = cute::logical_divide(tCrC, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);
        if constexpr (!SwapAB) {
            Tensor tCrA_slice = cute::logical_divide(tCrA, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);
            gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA_slice, tCrB, tCrC_slice);
        } else {
            Tensor tCrB_slice = cute::logical_divide(tCrB, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);
            gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA, tCrB_slice, tCrC_slice);
        }
    } else {
        constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
        // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
        if constexpr (Is_RS) {
            if constexpr (!SwapAB) {
                warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
            } else {
                warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));
            }
        }
        warpgroup_fence_operand(tCrC);
        warpgroup_arrive();
        if constexpr (zero_init) {
            tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        }
        static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA));
        static constexpr int kMaxKIters = 16;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) {
            if constexpr (!SwapAB) {
                cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            } else {
                cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC);
            }
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
        // In the case of large kNumKIters, the compiler chooses to store the smem addresses
        // in registers, causing spills. This loop forces the compiler to recompute the addresses.
        if constexpr (kNumKIters > kMaxKIters) {
            // This will always be zero, just a way to force the compiler to recompute the smem
            // addresses. This results in USEL instructions. There's probably a better way to do this.
            int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 0 : 1;
            CUTLASS_PRAGMA_UNROLL
            for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) {
                if constexpr (!SwapAB) {
                    cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC);
                } else {
                    cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC);
                }
                tiled_mma.accumulate_ = GMMA::ScaleOut::One;
            }
        }
        warpgroup_commit_batch();
        if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
        warpgroup_fence_operand(tCrC);
        if constexpr (Is_RS) {
            if constexpr (!SwapAB) {
                warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
            } else {
                warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));
            }
        }
    }
}

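// Call-site sketch (Hopper path; the mma/tensor names are illustrative): accumulate
// S = Q @ K^T into a fresh accumulator and overlap the WGMMA wait with other work.
//
//     Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));
//     flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK, tSrS);
//     ...                      // independent work while the WGMMA pipeline drains
//     warpgroup_wait<0>();     // wg_wait < 0 above means the caller waits explicitly
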
////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool A_in_regs=false, bool B_in_regs=false, bool SwapAB=false,
         typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
         typename ThrCopyA, typename ThrCopyB, typename Hook>
CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                              Tensor4 const& tCsB, TiledMma tiled_mma,
                              TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                              ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) {
    if constexpr (SwapAB) {
        gemm_sm80<B_in_regs, A_in_regs>(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, smem_thr_copy_A, fn);
    } else {
        CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
        CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
        CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
        Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
        CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));          // M
        Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
        CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
        if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
        if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
        #pragma unroll
        for (int i = 0; i < size<2>(tCrA); ++i) {
            if (i < size<2>(tCrA) - 1) {
                if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
                if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
            }
            if constexpr (!std::is_same_v<Hook, std::nullptr_t>) {
                if (i == 0) { fn(); }
            }
            cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                 TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                 ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init=false, typename Atom, typename TA, typename TB, typename TC>
CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
    static constexpr int rA = decltype(rank(tA))::value;
    static constexpr int rB = decltype(rank(tB))::value;
    static constexpr int rC = decltype(rank(tC))::value;
    static_assert(rA == 3 && rB == 3 && rC == 3);
    if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tA); k_block++) {
        cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
        atom.accumulate_ = decltype(atom.accumulate_)::One;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
    TiledMMA<MMA_Atom<
               MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
                          cute::C<M>, cute::C<N>,
                          cute::integral_constant<UMMA::Major, a_major>,
                          cute::integral_constant<UMMA::Major, b_major>,
                          cute::integral_constant<UMMA::ScaleIn, a_neg>,
                          cute::integral_constant<UMMA::ScaleIn, b_neg>>,
               TAs...>, TMs...>) {
    return TiledMMA<MMA_Atom<
                      MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg, UMMA::Saturate::False>>,
                      TAs...>, TMs...>{};
}

template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
    TiledMMA<MMA_Atom<
               SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
                                    M, N,
                                    a_major, b_major,
                                    a_neg, b_neg>,
               TAs...>, TMs...>) {
    return TiledMMA<MMA_Atom<
                      SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
                                           M, N,
                                           a_major, b_major,
                                           a_neg, b_neg, UMMA::Saturate::False>,
                      TAs...>, TMs...>{};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          class CopyAtom, class TV, class Tiler, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
CUTLASS_DEVICE void copy(TiledCopy<CopyAtom, TV, Tiler> const &tiled_copy, Tensor<Engine0, Layout0> const &S,
                         Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                         Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    // Decay TiledCopy to CopyAtom
    auto copy_atom = static_cast<CopyAtom const&>(tiled_copy);
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    auto has_with_bool = cute::is_valid([](auto t) -> void_t<decltype(declval<typename decltype(t)::Traits>().with(true))> {}, copy_atom);
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN;
        if constexpr (Is_even_MN || !Clear_OOB_MN) {
            if (Is_even_MN || predicate_mn) {
                #pragma unroll
                for (int k = 0; k < size<2>(S); ++k) {
                    if constexpr (Is_even_K || !Clear_OOB_K) {
                        if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); }
                    } else {  // Clear_OOB_K == true && Is_even_K == false
                        // If copy traits can be transformed with a predicate value, do it, otherwise branch here
                        if constexpr (has_with_bool) {
                            cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k));
                        } else {
                            if (predicate_K(k)) {
                                cute::copy(copy_atom, S(_, m, k), D(_, m, k));
                            } else {
                                cute::clear(D(_, m, k));
                            }
                        }
                    }
                }
            }
        } else {  // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true
            if constexpr (!has_with_bool) {
                if (predicate_mn) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(S); ++k) {
                        if (Is_even_K || predicate_K(k)) {
                            cute::copy(copy_atom, S(_, m, k), D(_, m, k));
                        } else if (Clear_OOB_K) {
                            cute::clear(D(_, m, k));
                        }
                    }
                } else {
                    cute::clear(D(_, m, _));
                }
            } else {  // combine the mn predicate with the k predicate
                #pragma unroll
                for (int k = 0; k < size<2>(S); ++k) {
                    cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}

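// Predicated-copy usage sketch (gmem -> smem for a K tile; names are illustrative).
// identity_MN comes from an identity tensor so each row index can be bounds-checked
// against max_MN, and predicate_K holds per-column validity for a ragged head dim.
//
//     Tensor cKV   = cute::make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));
//     Tensor tKcKV = gmem_thr_copy.partition_S(cKV);
//     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(
//         gmem_tiled_copy, tKgK, tKsK, tKcKV, tKpK, seqlen_k - block_start);
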
////////////////////////////////////////////////////////////////////////////////////////////////////

// Byte permute and shuffle to match register layout of
// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
template <typename Fragment>
CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) {
    // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits
    static_assert(decltype(size<0, 0>(frag))::value == 4);
    static_assert(decltype(size<0, 1>(frag))::value == 2);
    static_assert(decltype(stride<0, 0>(frag))::value == 1);
    static_assert(decltype(stride<0, 1>(frag))::value == 4);
    static_assert(sizeof(typename Fragment::value_type) == 1);
    int quad_idx = threadIdx.x % 4;
    bool lane_03 = quad_idx == 0 || quad_idx == 3;
    int selector_upper = lane_03 ? 0x5410 : 0x1054;
    int selector_lower = lane_03 ? 0x7632 : 0x3276;
    static constexpr int upper_map[4] = {0, 3, 1, 2};
    // static constexpr int lower_map[4] = {1, 2, 0, 3};
    Tensor frag_64b = recast<uint2>(frag);  // ((1, 1, 2), MMA_M, MMA_N)
    #pragma unroll
    for (int i = 0; i < size(frag_64b); ++i) {
        uint32_t upper = frag_64b[i].x;
        uint32_t lower = frag_64b[i].y;
        uint32_t upper0 = lane_03 ? upper : lower;
        uint32_t lower0 = lane_03 ? lower : upper;
        upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
        // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
        lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4);
        frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper);
        frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Fragment>
CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
    static_assert(decltype(size<0, 0>(frag))::value == 2);
    static_assert(decltype(size<0, 1>(frag))::value == 2);
    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
    static_assert(decltype(stride<0, 0>(frag))::value == 1);
    static_assert(sizeof(typename Fragment::value_type) == 4);
    Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
    #pragma unroll
    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
        #pragma unroll
        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Fragment>
CUTLASS_DEVICE void permute_output_fp8(Fragment &out) {
    // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
    static_assert(decltype(size<0, 0>(out))::value == 2);
    static_assert(decltype(size<0, 1>(out))::value == 2);
    static_assert(decltype(size<0, 2>(out))::value % 2 == 0);
    static_assert(decltype(stride<0, 0>(out))::value == 1);
    static_assert(sizeof(typename Fragment::value_type) == 4);
    Tensor frag = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))
    #pragma unroll
    for (int mi = 0; mi < size<1>(frag); ++mi) {
        #pragma unroll
        for (int j = 0; j < size<0, 1>(frag); ++j) {
            #pragma unroll
            for (int i = 0; i < size<0, 2>(frag) / 2; ++i) {
                cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi));
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Fragment>
CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) {
    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits
    static_assert(decltype(size<0, 0>(frag))::value == 2);
    static_assert(decltype(size<0, 1>(frag))::value == 2);
    static_assert(decltype(stride<0, 0>(frag))::value == 1);
    static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4);
    int quad_idx = threadIdx.x % 4;
    bool lane_03 = quad_idx == 0 || quad_idx == 3;
    static constexpr int upper_map[4] = {0, 2, 3, 1};
    // static constexpr int lower_map[4] = {2, 0, 1, 3};
    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
    using type2 = std::conditional_t<sizeof(typename Fragment::value_type) == 2, uint32_t, uint64_t>;
    Tensor frag_2 = group_modes<1, 3>(recast<type2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
    // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_2); }
    #pragma unroll
    for (int mi = 0; mi < size<1>(frag_2); ++mi) {
        #pragma unroll
        for (int j = 0; j < size<0, 1>(frag_2); ++j) {
            #pragma unroll
            for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) {
                type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi);
                type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi);
                type2 upper0 = lane_03 ? upper : lower;
                type2 lower0 = lane_03 ? lower : upper;
                upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
                // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
                lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4);
                frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0;
                frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? lower0 : upper0;
            }
        }
    }
    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
CUTLASS_DEVICE void apply_softcap(Tensor<Engine, Layout> &tensor, float const softcap) {
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
    }
}

template <typename Engine, typename Layout>
CUTLASS_DEVICE auto calculate_dtanh(Tensor<Engine, Layout> &tensor) {
    Tensor out = make_fragment_like<float>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        out(i) = 1.f - (tensor(i) * tensor(i));
    }
    return out;
}

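// Attention soft-capping sketch: the capped score is c * tanh(s / c), with the
// constants folded by the caller so this helper only applies the tanh. The names
// `tSrS` and `params.softcap_val` below are illustrative.
//
//     flash::apply_softcap(tSrS, params.softcap_val);   // forward: tanh-based capping
//     Tensor dtanh = flash::calculate_dtanh(tSrS);      // backward: d/dx tanh(x) = 1 - tanh(x)^2,
//                                                       // computed from the already-capped values
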
////////////////////////////////////////////////////////////////////////////////////////////////////

template<class T>
CUTE_DEVICE T warp_prefix_sum(T val) {
    int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) {
        T partial_sum = __shfl_up_sync(0xffffffff, val, i);
        if (lane >= i) { val += partial_sum; }
    }
    return val;
}

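// Inclusive warp-scan sketch: after the call, lane i holds the sum of the inputs
// of lanes 0..i. For example, if every lane passes 1, lane i receives i + 1.
//
//     int my_count  = ...;                                 // per-lane value (illustrative)
//     int inclusive = flash::warp_prefix_sum(my_count);    // inclusive prefix sum
//     int exclusive = inclusive - my_count;                // exclusive variant if needed
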
////////////////////////////////////////////////////////////////////////////////////////////////////

template<class T>
CUTE_DEVICE T warp_uniform(T a) {
    return __shfl_sync(0xffffffff, a, 0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

CUTLASS_DEVICE
int canonical_warp_group_idx_nosync() {
    return threadIdx.x / cutlass::NumThreadsPerWarpGroup;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash