// utils.h (28 KB)
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <assert.h>
  6. #include <stdint.h>
  7. #include <stdlib.h>
  8. #include <cuda_fp16.h>
  9. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  10. #include <cuda_bf16.h>
  11. #endif
  12. #include <cute/tensor.hpp>
  13. #include <cutlass/cutlass.h>
  14. #include <cutlass/array.h>
  15. #include <cutlass/numeric_conversion.h>
  16. #include <cutlass/numeric_types.h>
  // CHECK_CUDA(call): evaluates a CUDA runtime expression exactly once. If the
  // result is not cudaSuccess, prints "CUDA error (<file>:<line>): <message>" to
  // stderr (message from cudaGetErrorString) and terminates with exit(1).
  // The argument is parenthesized so that any expression (for example one that
  // contains a comma operator) is bound as a single initializer.
  // The do { ... } while(0) wrapper makes the macro behave as one statement.
  #define CHECK_CUDA(call)                                                                                  \
      do {                                                                                                  \
          cudaError_t status_ = (call);                                                                     \
          if (status_ != cudaSuccess) {                                                                     \
              fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
              exit(1);                                                                                      \
          }                                                                                                 \
      } while(0)
  // CHECK_CUDA_KERNEL_LAUNCH(): applies CHECK_CUDA to cudaGetLastError(), i.e.
  // reports (and clears) the last error recorded by the runtime.
  #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
  26. namespace flash {
  27. using namespace cute;
  28. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Kernel wrappers that guard against compilation on architectures that will
  // never use the kernel; the purpose is to reduce the size of the compiled binary.
  // Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55
  //
  // enable_sm90_or_later<Kernel>: forwards operator() to Kernel::operator() only
  // when compiling device code with __CUDA_ARCH__ >= 900; otherwise the body is empty.
  template <typename Kernel>
  struct enable_sm90_or_later : Kernel {
      template <typename... Params>
      CUTLASS_DEVICE void operator()(Params&&... params) {
  #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
          // Not an SM90+ device pass: intentionally compiled to nothing.
  #else
          Kernel::operator()(std::forward<Params>(params)...);
  #endif
      }
  };
  // enable_sm80_to_sm89<Kernel>: forwards operator() to Kernel::operator() only
  // when compiling device code with 800 <= __CUDA_ARCH__ <= 890; for any other
  // compilation pass the body is empty, so no kernel code is emitted.
  template <typename Kernel>
  struct enable_sm80_to_sm89 : Kernel {
      template <typename... Params>
      CUTLASS_DEVICE void operator()(Params&&... params) {
  #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 800) || (__CUDA_ARCH__ > 890)
          // Outside the SM80..SM89 range: intentionally compiled to nothing.
  #else
          Kernel::operator()(std::forward<Params>(params)...);
  #endif
      }
  };
  51. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Binary "maximum" functor for device code: returns lhs when lhs > rhs,
  // otherwise rhs. Only operator> of T is used.
  template<typename T>
  struct MaxOp {
      __device__ __forceinline__ T operator()(T const & lhs, T const & rhs) {
          if (lhs > rhs) { return lhs; }
          return rhs;
      }
  };
  // float specialization of MaxOp that calls max() instead of comparing and
  // selecting by hand.
  template <>
  struct MaxOp<float> {
      // This is slightly faster
      __device__ __forceinline__ float operator()(float const &lhs, float const &rhs) {
          return max(lhs, rhs);
      }
  };
  61. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Binary "sum" functor for device code: returns lhs + rhs converted to T.
  template<typename T>
  struct SumOp {
      __device__ __forceinline__ T operator()(T const & lhs, T const & rhs) {
          return lhs + rhs;
      }
  };
  66. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Butterfly all-reduce across THREADS lanes using XOR shuffles.
  // Each step combines a lane's value with the value held by the lane whose index
  // differs in the bit THREADS / 2, then recurses with half the span until the
  // Allreduce<2> specialization finishes the reduction.
  // The shuffle uses a full participation mask (all 32 bits set).
  template<int THREADS>
  struct Allreduce {
      static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
      // x:  this lane's contribution.
      // op: binary functor used to combine two values (e.g. MaxOp / SumOp).
      // Returns the reduced value.
      template<typename T, typename Operator>
      static __device__ __forceinline__ T run(T x, Operator &op) {
          constexpr int kHalfSpan = THREADS / 2;
          T const combined = op(x, __shfl_xor_sync(0xffffffffu, x, kHalfSpan));
          return Allreduce<kHalfSpan>::run(combined, op);
      }
  };
  77. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Terminal case of the butterfly all-reduce: combines a lane's value with the
  // value of the lane whose index differs in bit 0 (XOR shuffle with lane mask 1,
  // full participation mask).
  template<>
  struct Allreduce<2> {
      template<typename T, typename Operator>
      static __device__ __forceinline__ T run(T x, Operator &op) {
          T const combined = op(x, __shfl_xor_sync(0xffffffffu, x, 1));
          return combined;
      }
  };
  86. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Re-expresses an accumulator layout as a 2-mode (row, column) layout.
  // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
  // For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
  // With Transposed == true the two result modes are exchanged (columns first, rows second).
  // The SM90 path is selected when mode 0 of acc_layout itself has rank 3.
  template<bool Transposed=false, typename Layout0>
  CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) {
      constexpr bool kIsSm90 = decltype(rank<0>(acc_layout))::value == 3;
      if constexpr (kIsSm90) {
          static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
          static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
          static_assert(decltype(rank(acc_layout))::value == 3);
          // Rows: second sub-mode of the value mode together with MMA_M.
          auto row_modes = make_layout(get<0, 1>(acc_layout), get<1>(acc_layout));
          // Columns: first and third sub-modes of the value mode together with MMA_N.
          auto col_modes = make_layout(get<0, 0>(acc_layout), get<0, 2>(acc_layout), get<2>(acc_layout));
          if constexpr (Transposed) {
              return make_layout(col_modes, row_modes);
          } else {
              return make_layout(row_modes, col_modes);
          }
      } else {  // SM80
          static_assert(decltype(size<0>(acc_layout))::value == 4);
          static_assert(decltype(rank(acc_layout))::value == 3);
          auto split = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
          auto row_modes = make_layout(get<0, 1>(split), get<1>(split));
          auto col_modes = make_layout(get<0, 0>(split), get<2>(split));
          if constexpr (Transposed) {
              return make_layout(col_modes, row_modes);
          } else {
              return make_layout(row_modes, col_modes);
          }
      }
  }
  112. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Re-expresses an accumulator layout in the form expected for the A operand of a
  // following MMA.
  // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
  // if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
  // For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
  // For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
  // The SM90 path is selected when mode 0 of acc_layout itself has rank 3; the
  // 16-bit vs 8-bit variant is chosen from sizeof(MMA_Traits::ValTypeA).
  template<typename MMA_Traits, typename Layout0>
  CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) {
      if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
          static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
          static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
          static_assert(decltype(rank(acc_layout))::value == 3);
          static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
          constexpr int kBytesA = sizeof(typename MMA_Traits::ValTypeA);
          if constexpr (kBytesA == 2) {
              auto v_split = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))
              auto value_mode = make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(v_split));
              auto n_mode = coalesce(make_layout(get<0, 1>(v_split), get<2>(acc_layout)));
              return make_layout(value_mode, get<1>(acc_layout), n_mode);
          } else {
              static_assert(kBytesA == 1);
              static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
              static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
              auto v_split = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))
              // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
              // Will require register shuffling later to be correct.
              auto value_mode = make_layout(Layout<_4>{}, get<0, 0, 0>(v_split), get<0, 0, 1>(v_split));
              auto n_mode = coalesce(make_layout(get<0, 1>(v_split), get<2>(acc_layout)));
              return make_layout(value_mode, get<1>(acc_layout), n_mode);  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
              // This combination is right but doesn't work with register shuffling.
              // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
              //                    get<1>(acc_layout),
              //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
          }
      } else {  // SM80
          static_assert(decltype(size<0>(acc_layout))::value == 4);
          static_assert(decltype(rank(acc_layout))::value == 3);
          constexpr int kMmaK = get<2>(typename MMA_Traits::Shape_MNK{});
          static_assert(kMmaK == 8 || kMmaK == 16);
          if constexpr (kMmaK == 16) {
              // Split MMA_N in two and fold the factor of 2 into the value mode.
              auto n_split = logical_divide(acc_layout, Shape<Underscore, Underscore, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
              auto value_mode = make_layout(get<0>(n_split), get<2, 0>(n_split));
              return make_layout(value_mode, get<1>(n_split), get<2, 1>(n_split));
          } else {
              // m16n8k8: the accumulator layout is already usable as-is.
              return acc_layout;
          }
      }
  }
  156. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Converts every element of `tensor` to To_type in one NumericArrayConverter
  // call and returns a tensor with the same layout over the converted values.
  //
  // Unsafe because the returned tensor refers to memory allocated on this
  // function's stack. If the compiler does not inline this function, that memory
  // might not be valid in the caller.
  template <typename To_type, typename Engine, typename Layout>
  CUTLASS_DEVICE auto convert_type_unsafe(Tensor<Engine, Layout> const &tensor) {
      using From_type = typename Engine::value_type;
      static constexpr int numel = decltype(size(tensor))::value;
      using SrcArray = cutlass::Array<From_type, numel>;
      cutlass::NumericArrayConverter<To_type, From_type, numel> converter;
      // HACK: this requires tensor to be "contiguous", since its storage is
      // reinterpreted as a flat array of numel elements.
      SrcArray const *src = reinterpret_cast<const SrcArray *>(tensor.data());
      auto frag = converter(*src);
      return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
  }
  168. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Converts the elements of `tensor` into the caller-provided tensor `out`
  // (same Layout type), working on small fixed-size fragments.
  // The fragment width is the ratio between the larger and the smaller of the two
  // element sizes; the total element count must be a multiple of it.
  // Somehow if we allocate out inside this function and return it, e2e is slower
  // and the output can be wrong, hence the output parameter.
  template <typename Engine, typename Layout, typename EngineOut>
  CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
      using From_type = typename Engine::value_type;
      using To_type = typename EngineOut::value_type;
      static constexpr int FragmentSize = sizeof(From_type) >= sizeof(To_type)
          ? sizeof(From_type) / sizeof(To_type)
          : sizeof(To_type) / sizeof(From_type);
      static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
      auto src_frags = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
      auto dst_frags = recast<cutlass::Array<To_type, FragmentSize>>(out);
      static_assert(size(src_frags) == size(dst_frags));
      cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> converter;
      #pragma unroll
      for (int f = 0; f < size(src_frags); ++f) {
          dst_frags[f] = converter(src_frags[f]);
      }
  }
  183. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Blocks until all but N previous cp.async.commit_group operations have committed.
  // This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
  // (which is equivalent to commit_group then wait_group 0).
  // Instead we just call cp.async.wait_group 0, which is slightly faster.
  // https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
  // When CUTE_ARCH_CP_ASYNC_SM80_ENABLED is not defined the function is a no-op.
  template <int N>
  CUTE_HOST_DEVICE
  void cp_async_wait() {
  #ifdef CUTE_ARCH_CP_ASYNC_SM80_ENABLED
      // N is passed as an immediate ("n" constraint) operand.
      asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
  #endif
  }
  196. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Compile-time switch between the two operand partitioners of an MMA object:
  // returns mma.partition_fragment_A(tensor0) when A is true, otherwise
  // mma.partition_fragment_B(tensor0).
  template <bool A, class Mma, class Tensor0>
  CUTLASS_DEVICE
  auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
      if constexpr (!A) {
          return mma.partition_fragment_B(tensor0);
      } else {
          return mma.partition_fragment_A(tensor0);
      }
  }
  206. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Warpgroup GEMM helper: issues tiled_mma over every K block of the operands
  // and accumulates into tCrC.
  //
  // Template parameters:
  //   zero_init - when true, the first K block is issued with ScaleOut::Zero.
  //               After every K block accumulate_ is set to ScaleOut::One.
  //   wg_wait   - when >= 0, warpgroup_wait<wg_wait>() is called after the batch is
  //               committed; a negative value skips the wait.
  //   SwapAB    - when true, tCrB is passed as the first MMA operand and tCrA as
  //               the second.
  //   M_slice   - when >= 0, only slice M_slice along mode 1 of the accumulator
  //               (and of the operand supplying that mode) is computed, by
  //               recursing on sliced views with M_slice = -1.
  //
  // The statement order (fence, arrive, MMAs, commit, optional wait, fence) is
  // significant and must be kept.
  template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
            typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
  CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
      if constexpr (M_slice < 0) {
          // An operand is register-sourced ("RS") when the A fragment type is not
          // a GMMA descriptor iterator.
          constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
          // Need to cast away const on the operand since warpgroup_fence_operand doesn't take const
          if constexpr (Is_RS && !SwapAB) {
              warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
          } else if constexpr (Is_RS && SwapAB) {
              warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));
          }
          warpgroup_fence_operand(tCrC);
          warpgroup_arrive();
          if constexpr (zero_init) {
              tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
          }
          // Unroll the K mode manually to set scale D to 1
          CUTLASS_PRAGMA_UNROLL
          for (int kb = 0; kb < size<2>(tCrA); ++kb) {
              if constexpr (SwapAB) {
                  cute::gemm(tiled_mma, tCrB(_,_,kb), tCrA(_,_,kb), tCrC);
              } else {
                  cute::gemm(tiled_mma, tCrA(_,_,kb), tCrB(_,_,kb), tCrC);
              }
              tiled_mma.accumulate_ = GMMA::ScaleOut::One;
          }
          warpgroup_commit_batch();
          if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
          warpgroup_fence_operand(tCrC);
          if constexpr (Is_RS && !SwapAB) {
              warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
          } else if constexpr (Is_RS && SwapAB) {
              warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));
          }
      } else {
          static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
          static_assert(M_slice < MMA_M);
          // Tiler / coordinate shared by every sliced view below.
          auto const m_tiler = Shape<cute::Underscore, Int<MMA_M>>{};
          auto const m_coord = make_coord(Int<M_slice>{}, _);
          // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N)
          auto tCrC_slice = cute::logical_divide(tCrC, m_tiler)(_, m_coord, _);
          if constexpr (SwapAB) {
              auto tCrB_slice = cute::logical_divide(tCrB, m_tiler)(_, m_coord, _);
              gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA, tCrB_slice, tCrC_slice);
          } else {
              auto tCrA_slice = cute::logical_divide(tCrA, m_tiler)(_, m_coord, _);
              gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA_slice, tCrB, tCrC_slice);
          }
      }
  }
  259. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // SM80-style GEMM: copies the A and B operands block by block along K from the
  // tCsA / tCsB source tensors into tCrA / tCrB and accumulates into acc.
  // The copy for K block i + 1 is issued before the MMA for K block i.
  //
  // Template parameters:
  //   A_in_regs / B_in_regs - when true, the corresponding operand is assumed to be
  //                           in tCrA / tCrB already and is not copied.
  //   SwapAB                - when true, the call is forwarded with the roles of all
  //                           A and B arguments (and the *_in_regs flags) exchanged.
  // fn is invoked once, right before the first MMA, unless Hook is std::nullptr_t.
  template<bool A_in_regs=false, bool B_in_regs=false, bool SwapAB=false,
           typename Tensor0, typename Tensor1,
           typename Tensor2, typename Tensor3, typename Tensor4,
           typename TiledMma, typename TiledCopyA, typename TiledCopyB,
           typename ThrCopyA, typename ThrCopyB, typename Hook>
  CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                                Tensor4 const& tCsB, TiledMma tiled_mma,
                                TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                                ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) {
      if constexpr (!SwapAB) {
          CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
          CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
          CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
          // Views of the operand fragments retiled as copy destinations.
          auto tCrA_dst = smem_thr_copy_A.retile_D(tCrA);
          CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_dst));  // M
          auto tCrB_dst = smem_thr_copy_B.retile_D(tCrB);
          CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_dst));  // N
          // Prologue: fetch K block 0.
          if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_dst(_, _, _0{})); }
          if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_dst(_, _, _0{})); }
          #pragma unroll
          for (int kb = 0; kb < size<2>(tCrA); ++kb) {
              int const kb_next = kb + 1;
              if (kb_next < size<2>(tCrA)) {
                  // Fetch the next K block ahead of the MMA on the current one.
                  if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, kb_next), tCrA_dst(_, _, kb_next)); }
                  if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, kb_next), tCrB_dst(_, _, kb_next)); }
              }
              if constexpr (!std::is_same_v<Hook, std::nullptr_t>) {
                  if (kb == 0) { fn(); }
              }
              cute::gemm(tiled_mma, tCrA(_, _, kb), tCrB(_, _, kb), acc);
          }
      } else {
          // Same routine with every A/B argument exchanged (SwapAB defaults to false).
          gemm_sm80<B_in_regs, A_in_regs>(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma,
                                          smem_tiled_copy_B, smem_tiled_copy_A,
                                          smem_thr_copy_B, smem_thr_copy_A, fn);
      }
  }
  294. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // SM80-style GEMM in which operand A is used directly from tCrA and only
  // operand B is copied, block by block along K, from tCsB into tCrB.
  // The copy of B for K block i + 1 is issued before the MMA for K block i.
  // Results are accumulated into acc.
  template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
           typename TiledMma, typename TiledCopy, typename ThrCopy>
  CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                   TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                   ThrCopy smem_thr_copy_B) {
      CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
      CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
      CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
      // View of the B fragment retiled as a copy destination.
      auto tCrB_dst = smem_thr_copy_B.retile_D(tCrB);
      CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_dst));  // N
      // Prologue: fetch K block 0 of B.
      cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_dst(_, _, _0{}));
      #pragma unroll
      for (int kb = 0; kb < size<2>(tCrA); ++kb) {
          int const kb_next = kb + 1;
          if (kb_next < size<2>(tCrA)) {
              cute::copy(smem_tiled_copy_B, tCsB(_, _, kb_next), tCrB_dst(_, _, kb_next));
          }
          cute::gemm(tiled_mma, tCrA(_, _, kb), tCrB(_, _, kb), acc);
      }
  }
  314. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Predicated tile copy from S to D (both rank 3, indexed as (_, m, k)).
  //
  // Template parameters:
  //   Is_even_MN   - when true, no per-row bounds check is done.
  //   Is_even_K    - when true, predicate_K is not consulted.
  //   Clear_OOB_MN - when true, destination rows that fail the row check are cleared
  //                  (otherwise they are left untouched).
  //   Clear_OOB_K  - when true, destination columns whose predicate_K is false are
  //                  cleared (otherwise they are left untouched).
  // Row m is in bounds when get<0>(identity_MN(0, m, 0)) < max_MN.
  // If the copy atom's traits provide with(bool), out-of-bounds elements that must
  // be cleared are handled by passing the predicate to the atom instead of
  // branching and calling cute::clear.
  template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
            class CopyAtom, class TV, class Tiler, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
            typename Engine2, typename Layout2, typename Engine3, typename Layout3>
  CUTLASS_DEVICE void copy(TiledCopy<CopyAtom, TV, Tiler> const &tiled_copy, Tensor<Engine0, Layout0> const &S,
                           Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                           Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
      // Decay TiledCopy to CopyAtom
      auto copy_atom = static_cast<CopyAtom const&>(tiled_copy);
      CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
      CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
      CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
      CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
      CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
      // There's no case where !Clear_OOB_K && Clear_OOB_MN
      static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
      // True-valued when the atom's traits expose a with(bool) member.
      auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);
      // Rows / columns never need clearing in these configurations.
      constexpr bool kNoRowClear = Is_even_MN || !Clear_OOB_MN;
      constexpr bool kNoColClear = Is_even_K || !Clear_OOB_K;
      #pragma unroll
      for (int mi = 0; mi < size<1>(S); ++mi) {
          bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, mi, _0{})) < max_MN;
          if constexpr (!kNoRowClear) {
              // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true
              if constexpr (has_with_bool) {
                  // combine the mn predicate with the k predicate
                  #pragma unroll
                  for (int ki = 0; ki < size<2>(S); ++ki) {
                      cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(ki))), S(_, mi, ki), D(_, mi, ki));
                  }
              } else {
                  if (!predicate_mn) {
                      // Whole row is out of bounds: zero it.
                      cute::clear(D(_, mi, _));
                  } else {
                      #pragma unroll
                      for (int ki = 0; ki < size<2>(S); ++ki) {
                          if (Is_even_K || predicate_K(ki)) {
                              cute::copy(copy_atom, S(_, mi, ki), D(_, mi, ki));
                          } else if (Clear_OOB_K) {
                              cute::clear(D(_, mi, ki));
                          }
                      }
                  }
              }
          } else {
              // Out-of-bounds rows are simply skipped.
              if (Is_even_MN || predicate_mn) {
                  #pragma unroll
                  for (int ki = 0; ki < size<2>(S); ++ki) {
                      if constexpr (!kNoColClear) {
                          // Clear_OOB_K == true && Is_even_K == false
                          // If copy traits can be transformed with a predicate value, do it, otherwise branch here
                          if constexpr (has_with_bool) {
                              cute::copy(copy_atom.with(predicate_K(ki)), S(_, mi, ki), D(_, mi, ki));
                          } else {
                              if (predicate_K(ki)) {
                                  cute::copy(copy_atom, S(_, mi, ki), D(_, mi, ki));
                              } else {
                                  cute::clear(D(_, mi, ki));
                              }
                          }
                      } else {
                          if (Is_even_K || predicate_K(ki)) { cute::copy(copy_atom, S(_, mi, ki), D(_, mi, ki)); }
                      }
                  }
              }
          }
      }
  }
  377. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Byte permute and shuffle to match register layout of
  // (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
  //
  // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits.
  // The fragment is processed as 64-bit (uint2) chunks. For every chunk the two
  // 32-bit halves are exchanged between the lanes of each group of 4 threads
  // (shuffle width 4) and then recombined with __byte_perm. Which half is sent
  // where, and the byte selectors, depend on threadIdx.x % 4.
  template <typename Fragment>
  CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) {
      static_assert(decltype(size<0, 0>(frag))::value == 4);
      static_assert(decltype(size<0, 1>(frag))::value == 2);
      static_assert(decltype(stride<0, 0>(frag))::value == 1);
      static_assert(decltype(stride<0, 1>(frag))::value == 4);
      static_assert(sizeof(typename Fragment::value_type) == 1);
      int const lane_in_quad = threadIdx.x % 4;
      bool const is_outer_lane = lane_in_quad == 0 || lane_in_quad == 3;
      int const selector_x = is_outer_lane ? 0x5410 : 0x1054;
      int const selector_y = is_outer_lane ? 0x7632 : 0x3276;
      // Source lane (within the quad) for the first exchanged word.
      static constexpr int first_src_map[4] = {0, 3, 1, 2};
      // The second word comes from the lane with bit 0 flipped, which is
      // equivalent to the table {1, 2, 0, 3}.
      int const first_src = first_src_map[lane_in_quad];
      int const second_src = first_src ^ 1;
      auto frag_64b = recast<uint2>(frag);  // ((1, 1, 2), MMA_M, MMA_N)
      #pragma unroll
      for (int r = 0; r < size(frag_64b); ++r) {
          uint32_t const word_x = frag_64b[r].x;
          uint32_t const word_y = frag_64b[r].y;
          // Inner lanes (1 and 2) send their two words in swapped roles.
          uint32_t first = is_outer_lane ? word_x : word_y;
          uint32_t second = is_outer_lane ? word_y : word_x;
          first = __shfl_sync(uint32_t(-1), first, first_src, 4);
          second = __shfl_sync(uint32_t(-1), second, second_src, 4);
          frag_64b[r].x = __byte_perm(first, second, selector_x);
          frag_64b[r].y = __byte_perm(first, second, selector_y);
      }
  }
  408. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // In-place register permutation of a fragment of 32-bit elements.
  // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits.
  // Viewed as 64-bit (uint2) pairs, for every (MMA_M, MMA_N) position and every
  // even index e along the third value sub-mode, the pair at (0, 1, e) is swapped
  // with the pair at (0, 0, e + 1).
  template <typename Fragment>
  CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
      static_assert(decltype(size<0, 0>(frag))::value == 2);
      static_assert(decltype(size<0, 1>(frag))::value == 2);
      static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
      static_assert(decltype(stride<0, 0>(frag))::value == 1);
      static_assert(sizeof(typename Fragment::value_type) == 4);
      auto pairs = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
      #pragma unroll
      for (int tile = 0; tile < size<1>(pairs); ++tile) {
          #pragma unroll
          for (int p = 0; p < size<0, 2>(pairs) / 2; ++p) {
              int const even = 2 * p;
              int const odd = 2 * p + 1;
              cutlass::swap(pairs(make_coord(_0{}, _1{}, even), tile),
                            pairs(make_coord(_0{}, _0{}, odd), tile));
          }
      }
  }
  426. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // In-place register permutation of an output fragment of 32-bit elements.
  // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits.
  // For every (MMA_M, MMA_N) position, every index j of the second value sub-mode
  // and every even index e along the third value sub-mode, the element at
  // (1, j, e) is swapped with the element at (0, j, e + 1).
  template <typename Fragment>
  CUTLASS_DEVICE void permute_output_fp8(Fragment &out) {
      static_assert(decltype(size<0, 0>(out))::value == 2);
      static_assert(decltype(size<0, 1>(out))::value == 2);
      static_assert(decltype(size<0, 2>(out))::value % 2 == 0);
      static_assert(decltype(stride<0, 0>(out))::value == 1);
      static_assert(sizeof(typename Fragment::value_type) == 4);
      auto grouped = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))
      #pragma unroll
      for (int tile = 0; tile < size<1>(grouped); ++tile) {
          #pragma unroll
          for (int j = 0; j < size<0, 1>(grouped); ++j) {
              #pragma unroll
              for (int p = 0; p < size<0, 2>(grouped) / 2; ++p) {
                  int const even = 2 * p;
                  int const odd = 2 * p + 1;
                  cutlass::swap(grouped(make_coord(_1{}, j, even), tile),
                                grouped(make_coord(_0{}, j, odd), tile));
              }
          }
      }
  }
  447. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // In-place register exchange between the lanes of each group of 4 threads.
  // frag has shape ((2, 2, N / 8), MMA_M, MMA_N); each element is 16 or 32 bits
  // (both sizes are accepted by the static_assert below).
  // Two adjacent elements are packed into one word (32 or 64 bits). For every
  // even/odd pair of such words along the third value sub-mode, the words are
  // shuffled across the quad (shuffle width 4) and written back; lanes 1 and 2 of
  // the quad send and store their two words in swapped roles.
  template <typename Fragment>
  CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) {
      static_assert(decltype(size<0, 0>(frag))::value == 2);
      static_assert(decltype(size<0, 1>(frag))::value == 2);
      static_assert(decltype(stride<0, 0>(frag))::value == 1);
      static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4);
      int const lane_in_quad = threadIdx.x % 4;
      bool const is_outer_lane = lane_in_quad == 0 || lane_in_quad == 3;
      // Source lane (within the quad) for the first exchanged word.
      static constexpr int first_src_map[4] = {0, 2, 3, 1};
      // The second word comes from the lane with bit 1 flipped, which is
      // equivalent to the table {2, 0, 1, 3}.
      int const first_src = first_src_map[lane_in_quad];
      int const second_src = first_src ^ 2;
      // Word type holding two fragment elements.
      using type2 = std::conditional_t<sizeof(typename Fragment::value_type) == 2, uint32_t, uint64_t>;
      auto words = group_modes<1, 3>(recast<type2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
      #pragma unroll
      for (int tile = 0; tile < size<1>(words); ++tile) {
          #pragma unroll
          for (int j = 0; j < size<0, 1>(words); ++j) {
              #pragma unroll
              for (int p = 0; p < size<0, 2>(words) / 2; ++p) {
                  int const even = 2 * p;
                  int const odd = 2 * p + 1;
                  type2 const word_even = words(make_coord(_0{}, j, even), tile);
                  type2 const word_odd = words(make_coord(_0{}, j, odd), tile);
                  type2 first = is_outer_lane ? word_even : word_odd;
                  type2 second = is_outer_lane ? word_odd : word_even;
                  first = __shfl_sync(uint32_t(-1), first, first_src, 4);
                  second = __shfl_sync(uint32_t(-1), second, second_src, 4);
                  words(make_coord(_0{}, j, even), tile) = is_outer_lane ? first : second;
                  words(make_coord(_0{}, j, odd), tile) = is_outer_lane ? second : first;
              }
          }
      }
  }
  483. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Replaces every element t of `tensor`, in place, with
  // cutlass::fast_tanh(t * softcap).
  template <typename Engine, typename Layout>
  CUTLASS_DEVICE void apply_softcap(Tensor<Engine, Layout> &tensor, float const softcap){
      #pragma unroll
      for (int idx = 0; idx < size(tensor); ++idx) {
          tensor(idx) = cutlass::fast_tanh(tensor(idx) * softcap);
      }
  }
  // Returns a new float fragment shaped like `tensor` whose elements are
  // 1 - t * t for each element t of `tensor`. The input is not modified.
  // presumably `tensor` already holds tanh values, making this the tanh
  // derivative — confirm against the caller.
  template <typename Engine, typename Layout>
  CUTLASS_DEVICE auto calculate_dtanh(Tensor<Engine, Layout> &tensor){
      auto result = make_fragment_like<float>(tensor);
      #pragma unroll
      for (int idx = 0; idx < size(tensor); ++idx) {
          result(idx) = 1.f - (tensor(idx) * tensor(idx));
      }
      return result;
  }
  500. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Returns the index of the calling thread's warp group within the block,
  // computed as threadIdx.x / cutlass::NumThreadsPerWarpGroup.
  // The value is computed per thread from threadIdx.x only; no shuffle or other
  // synchronizing instruction is involved.
  CUTLASS_DEVICE
  int canonical_warp_group_idx_nosync() {
      int const warp_group_idx = threadIdx.x / cutlass::NumThreadsPerWarpGroup;
      return warp_group_idx;
  }
  505. ////////////////////////////////////////////////////////////////////////////////////////////////////
  506. } // namespace flash