utils.h
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
__forceinline__ __device__ uint32_t relu2(const uint32_t x);

template<>
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n" \
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif
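
// Illustrative usage sketch (added commentary, not part of the original file): relu2 operates
// on two packed fp16/bf16 values held in a single 32-bit register, so a register fragment is
// typically recast to uint32_t first, e.g. (hypothetical values):
//
//     __half2 h2 = __floats2half2_rn(-1.0f, 2.0f);
//     uint32_t packed = reinterpret_cast<uint32_t&>(h2);
//     uint32_t clamped = flash::relu2<cutlass::half_t>(packed);  // -> (0.0, 2.0)
//
// See relu_() below for how this is applied across a whole register tensor.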

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);

template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
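
// Illustrative usage sketch (added commentary, not part of the original file), assuming all
// 32 lanes of a warp hold one partial value each:
//
//     MaxOp<float> max_op;
//     float row_max = Allreduce<32>::run(thread_partial_max, max_op);  // same result in every lane
//     SumOp<float> sum_op;
//     float row_sum = Allreduce<4>::run(thread_partial_sum, sum_op);   // reduce within groups of 4 lanes
//
// The butterfly __shfl_xor_sync pattern leaves the reduced value replicated in every
// participating lane, which is what the softmax row-max / row-sum code relies on.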

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
         typename ThrCopyA, typename ThrCopyB>
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                                     Tensor4 const& tCsB, TiledMma tiled_mma,
                                     TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                                     ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}
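
// Note on the loop structure above (added commentary, not from the original file): the copy of
// K-slice i+1 from shared memory into registers is issued before the MMA on K-slice i, so the
// smem->register traffic for the next step overlaps with the tensor-core work of the current
// step (a one-stage software pipeline). When A_in_regs / B_in_regs is true, that operand is
// assumed to already live entirely in registers and its copies are skipped.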

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                        TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                        ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};
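
// Worked shape example (added commentary, not from the original file): with the SM80 m16n8k16
// MMA, each thread holds 4 fp32 accumulator values per tile, arranged as 2 rows x 2 columns.
// For an accumulator fragment with layout shape (4, MMA_M=2, MMA_N=4), the conversion yields
// shape ((2, 2), (2, 4)), i.e. nrow = (2, MMA_M) and ncol = (2, MMA_N), so the softmax code can
// iterate "per row" and "per column" of the tile without caring about the raw MMA register layout.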

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    if constexpr (mma_shape_K == 8) {
        return acc_layout;
    } else {
        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
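
// Illustrative usage sketch (added commentary, not from the original file): converting fp32
// softmax probabilities back to the kernel's element type before the P @ V GEMM, assuming rP
// is a contiguous register tensor of floats and Element is cutlass::half_t or bfloat16_t:
//
//     Tensor rP_el = flash::convert_type<Element>(rP);  // same layout, fp16/bf16 values
//
// The reinterpret_cast above relies on the register fragment being contiguous; the function
// is not safe for arbitrary strided views.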

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
    using value_t = typename Engine::value_type;
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_uint32 = recast<uint32_t>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor_uint32); ++i) {
        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction.
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
    static_assert(std::is_same_v<float, From_type>);
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_float2 = recast<float2>(tensor);
    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
    #pragma unroll
    for (int i = 0; i < size(out_uint32); ++i) {
        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
    }
    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
    Tensor out = flash::convert_type<To_type>(tensor);
    flash::relu_(out);
#endif
    return out;
}
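
// Added commentary (not from the original file): on SM80+ the cvt.rn.relu.{f16x2,bf16x2}.f32
// instruction converts a float2 and clamps negatives to zero in one step; pre-SM80 the same
// result is obtained in two steps, convert_type<To_type> followed by relu_ on the packed
// half2/bf16x2 values. Both paths return a register tensor with the input's layout.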

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
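
// Illustrative usage sketch (added commentary, not from the original file), following the usual
// cp.async pattern: issue asynchronous gmem->smem copies, commit them as a group, then wait
// before reading the destination in shared memory:
//
//     cute::copy(gmem_tiled_copy, tSgS, tSsS);  // issues cp.async instructions
//     cute::cp_async_fence();                   // commit_group
//     flash::cp_async_wait<0>();                // wait for all committed groups
//     __syncthreads();
//
// cp_async_wait<1>() instead leaves one committed group in flight, which is how the kernels
// overlap the next tile's loads with the current tile's compute.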

////////////////////////////////////////////////////////////////////////////////////////////////////

// Resolves the offset of a thread's slice of a paged KV copy from gmem.
// Assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
                                            const int* block_table, const int page_stride, const int row_stride) {
    constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
    constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
    constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
    constexpr int kBlockN = Kernel_traits::kBlockN;

    const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
    const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
    const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
    const int64_t page_offset = global_row_offset % page_block_size;
    const int64_t virtual_page_idx = global_row_offset / page_block_size;

    return ((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride)
        + page_offset * ((int64_t) row_stride)
        + col_offset;
}
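
// Worked example (added commentary with made-up trait values, not from the original file):
// with kGmemThreadsPerRow = 8, kGmemRowsPerThread = 4, kGmemElemsPerLoad = 8, kBlockN = 64,
// thread tidx = 13, n_block_max = 3, page_block_size = 16:
//   col_offset        = (13 % 8) * 8 = 40 elements into the row
//   block_row_offset  = (13 / 8) * 4 = 4
//   global_row_offset = 4 + 2 * 64   = 132
//   virtual_page_idx  = 132 / 16     = 8, page_offset = 132 % 16 = 4
// so this thread reads from physical page block_table[8], row 4 of that page, column 40.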

////////////////////////////////////////////////////////////////////////////////////////////////////

// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_thread_tile(Layout<Shape, Stride> l) {
    return make_layout(append(get<0>(l.shape()), get<2>(l.shape())),
                       append(get<0>(l.stride()), get<2>(l.stride())));
}

// Reshapes and flattens the thread tile layout. A separate function is needed for the case where
// one of the modes of l is itself a layout and must be flattened, rather than kept intact as in
// the swizzled case above.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
    auto mode_0 = filter(flatten(get<0>(l)));
    return make_layout(append(mode_0.shape(), get<2>(l.shape())),
                       append(mode_0.stride(), get<2>(l.stride())));
}
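
// Shape example (added commentary, not from the original file): given a thread-tile layout with
// shape ((8, 4), 1, 2) -- value mode (v1=8, v2=4), m=1, k=2 -- reshape_thread_tile produces shape
// (8, 4, 2) with the corresponding strides, while reshape_flatten_thread_tile additionally
// flattens a nested/swizzled first mode before appending k, so the paged and regular copy paths
// hand identically shaped tensors to cute::copy.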

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause race condition.
    // I think it's because the copies are under an if statement.
    // if (Is_even_K) {
    //     #pragma unroll
    //     for (int m = 0; m < size<1>(S); ++m) {
    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
    //         } else if (Clear_OOB_MN) {
    //             clear(D(_, m, _));
    //         }
    //     }
    // } else {  // It's slightly faster in this case if we iterate over K first
    //     #pragma unroll
    //     for (int k = 0; k < size<2>(S); ++k) {
    //         if (predicate_K(k)) {
    //             #pragma unroll
    //             for (int m = 0; m < size<1>(S); ++m) {
    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
    //                 } else if (Clear_OOB_MN) {
    //                     clear(D(_, m, k));
    //                 }
    //             }
    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
    //             if (Clear_OOB_MN || Is_even_MN) {
    //                 clear(D(_, _, k));
    //             } else {
    //                 #pragma unroll
    //                 for (int m = 0; m < size<1>(S); ++m) {
    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
    //                         clear(D(_, m, k));
    //                     }
    //                 }
    //             }
    //         }
    //     }
    // }
}
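
// Semantics summary (added commentary, not from the original file): identity_MN is an identity
// tensor partitioned the same way as S/D, so get<0>(identity_MN(0, m, 0)) recovers the global
// row index of fragment row m; rows at or beyond max_MN are skipped (or cleared when
// Clear_OOB_MN), while predicate_K masks out-of-bounds columns in the head dimension (cleared
// when Clear_OOB_K). Typical call shape, with hypothetical names adapted from the calling kernels:
//
//     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(
//         gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
//         binfo.actual_seqlen_k - n_block * kBlockN);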

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                               Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                               Tensor<Engine3, Layout3> const &predicate_K,
                                               const int max_MN=0, const int min_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash