utils.h

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#define CHECK_CUDA(call) \
    do { \
        cudaError_t status_ = call; \
        if (status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1); \
        } \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
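// Usage sketch: a warp-level all-reduce of one value per thread, e.g. the running row max
// and row sum of an online softmax. `partial` is a hypothetical per-thread value, and all
// 32 lanes of the warp are assumed active.
//
//     MaxOp<float> max_op;
//     float row_max = Allreduce<32>::run(partial, max_op);  // max over the full warp
//     SumOp<float> sum_op;
//     float row_sum = Allreduce<4>::run(partial, sum_op);   // sum within each group of 4 lanes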
////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        if constexpr (!Transposed) {
            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
        } else {
            return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
        }
    } else {  // SM80
        static_assert(!Transposed);
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    }
};
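// Usage sketch: view an accumulator fragment as a (row, col) tensor so per-row softmax
// statistics can be computed. `acc_s` is a hypothetical accumulator fragment returned by
// partition_fragment_C of the tiled MMA.
//
//     Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
//     // scores(r, c) now indexes the rows and columns of the S = Q @ K^T tile held by this thread.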
////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM90, convert acc_layout from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
template<typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_transposed_rowcol(Layout0 acc_layout) {
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = acc_layout;
    return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))
            return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        } else {
            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))
            // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
            // Will require register shuffling later to be correct.
            return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
                               get<1>(acc_layout),
                               coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
            // This combination is right but doesn't work with register shuffling.
            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
            //                    get<1>(acc_layout),
            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
        }
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};
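// Usage sketch: reinterpret the (downcasted) accumulator of the first GEMM as operand A of
// the second GEMM without a round trip through shared memory. `rP` is a hypothetical register
// tensor holding the converted P = softmax(S) values (same layout as the accumulator), and
// `MmaTraitsPV` is a hypothetical MMA traits type providing ValTypeA and Shape_MNK.
//
//     Tensor tOrP = make_tensor(rP.data(), convert_layout_acc_Aregs<MmaTraitsPV>(rP.layout()));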
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
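// Usage sketch: downcast an fp32 accumulator fragment to the element type of the next GEMM.
// `acc_s` is a hypothetical fp32 register tensor; cutlass::half_t is a stand-in element type.
//
//     Tensor rP = convert_type<cutlass::half_t>(acc_s);  // same layout, fp16 values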
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type_safe(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    Tensor out = make_fragment_like<To_type>(tensor);
    constexpr int FragmentSize = sizeof(From_type) / sizeof(To_type);
    static_assert(CUTE_STATIC_V(size<0>(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
    Tensor frag = recast<cutlass::Array<From_type, FragmentSize>>(tensor);
    Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
    static_assert(size(frag) == size(out_frg));
    cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
    #pragma unroll
    for (int i = 0; i < size(frag); ++i) { out_frg(i) = convert_op(frag(i)); }
    // Tensor frag_32b = recast<uint32_t>(make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()));
    // Tensor out_32b = recast<uint32_t>(out);
    // // cute::copy(frag_32b, out_32b);
    // #pragma unroll
    // for (int i = 0; i < size(frag_32b); ++i) { out_32b[i] = frag_32b[i]; }
    return out;
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
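// Usage sketch: commit the cp.async copies of the current stage, then wait before consuming
// them. The surrounding pipeline is assumed, not shown.
//
//     cute::cp_async_fence();     // commit the cp.async instructions issued so far
//     flash::cp_async_wait<0>();  // wait until all committed groups have completed
//     __syncthreads();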
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
          typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
    if constexpr (M_slice >= 0) {
        static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
        static_assert(M_slice < MMA_M);
        // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N)
        Tensor tCrC_slice = cute::logical_divide(tCrC, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);
        if constexpr (!SwapAB) {
            Tensor tCrA_slice = cute::logical_divide(tCrA, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);
            gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA_slice, tCrB, tCrC_slice);
        } else {
            Tensor tCrB_slice = cute::logical_divide(tCrB, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);
            gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA, tCrB_slice, tCrC_slice);
        }
    } else {
        constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
        // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
        if constexpr (Is_RS) {
            if constexpr (!SwapAB) {
                warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
            } else {
                warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));
            }
        }
        warpgroup_fence_operand(tCrC);
        warpgroup_arrive();
        if constexpr (zero_init) {
            tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        }
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            if constexpr (!SwapAB) {
                cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            } else {
                cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC);
            }
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
        warpgroup_commit_batch();
        if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
        warpgroup_fence_operand(tCrC);
        if constexpr (Is_RS) {
            if constexpr (!SwapAB) {
                warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
            } else {
                warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));
            }
        }
    }
}
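// Usage sketch: issue the S = Q @ K^T warpgroup GEMM with a zeroed accumulator and return
// without waiting, then wait once the result is needed. tiled_mma_qk, tSrQ, tSrK and tSrS are
// hypothetical partitioned MMA operand/accumulator tensors.
//
//     flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK, tSrS);
//     // ... overlap other work here ...
//     warpgroup_wait<0>();  // wait for the batch committed by the gemm above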
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}
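// Usage sketch: a predicated gmem -> smem copy of a Q tile whose last row block may be partial.
// tQgQ, tQsQ, cQ (identity coordinates) and tQpQ (per-column K predicate) are hypothetical
// partitioned tensors; gmem_tiled_copy_Q, kBlockM, m_block and seqlen_q are assumed kernel state.
//
//     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(
//         gmem_tiled_copy_Q, tQgQ, tQsQ, cQ, tQpQ, seqlen_q - m_block * kBlockM);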
////////////////////////////////////////////////////////////////////////////////////////////////////

// Byte permute and shuffle to match register layout of
// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
template <typename Fragment>
CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) {
    // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits
    static_assert(decltype(size<0, 0>(frag))::value == 4);
    static_assert(decltype(size<0, 1>(frag))::value == 2);
    static_assert(decltype(stride<0, 0>(frag))::value == 1);
    static_assert(decltype(stride<0, 1>(frag))::value == 4);
    static_assert(sizeof(typename Fragment::value_type) == 1);
    int quad_idx = threadIdx.x % 4;
    bool lane_03 = quad_idx == 0 || quad_idx == 3;
    int selector_upper = lane_03 ? 0x5410 : 0x1054;
    int selector_lower = lane_03 ? 0x7632 : 0x3276;
    static constexpr int upper_map[4] = {0, 3, 1, 2};
    static constexpr int lower_map[4] = {1, 2, 0, 3};
    Tensor frag_64b = recast<uint2>(frag);  // ((1, 1, 2), MMA_M, MMA_N)
    #pragma unroll
    for (int i = 0; i < size(frag_64b); ++i) {
        uint32_t upper = frag_64b[i].x;
        uint32_t lower = frag_64b[i].y;
        uint32_t upper0 = lane_03 ? upper : lower;
        uint32_t lower0 = lane_03 ? lower : upper;
        upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
        lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
        frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper);
        frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Fragment>
CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
    static_assert(decltype(size<0, 0>(frag))::value == 2);
    static_assert(decltype(size<0, 1>(frag))::value == 2);
    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
    static_assert(decltype(stride<0, 0>(frag))::value == 1);
    static_assert(sizeof(typename Fragment::value_type) == 4);
    Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
    #pragma unroll
    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
        #pragma unroll
        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Fragment>
CUTLASS_DEVICE void permute_output_fp8(Fragment &out) {
    // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
    static_assert(decltype(size<0, 0>(out))::value == 2);
    static_assert(decltype(size<0, 1>(out))::value == 2);
    static_assert(decltype(size<0, 2>(out))::value % 2 == 0);
    static_assert(decltype(stride<0, 0>(out))::value == 1);
    static_assert(sizeof(typename Fragment::value_type) == 4);
    Tensor frag = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))
    #pragma unroll
    for (int mi = 0; mi < size<1>(frag); ++mi) {
        #pragma unroll
        for (int j = 0; j < size<0, 1>(frag); ++j) {
            #pragma unroll
            for (int i = 0; i < size<0, 2>(frag) / 2; ++i) {
                cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi));
            }
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Fragment>
CUTLASS_DEVICE void permute_output_fp8_fp16(Fragment &frag) {
    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits
    static_assert(decltype(size<0, 0>(frag))::value == 2);
    static_assert(decltype(size<0, 1>(frag))::value == 2);
    static_assert(decltype(stride<0, 0>(frag))::value == 1);
    static_assert(sizeof(typename Fragment::value_type) == 2);
    int quad_idx = threadIdx.x % 4;
    bool lane_03 = quad_idx == 0 || quad_idx == 3;
    static constexpr int upper_map[4] = {0, 2, 3, 1};
    static constexpr int lower_map[4] = {2, 0, 1, 3};
    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
    Tensor frag_32b = group_modes<1, 3>(recast<uint32_t>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
    // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_32b); }
    #pragma unroll
    for (int mi = 0; mi < size<1>(frag_32b); ++mi) {
        #pragma unroll
        for (int j = 0; j < size<0, 1>(frag_32b); ++j) {
            #pragma unroll
            for (int i = 0; i < size<0, 2>(frag_32b) / 2; ++i) {
                // cutlass::swap(frag_32b(make_coord(_0{}, j, 2 * i), mi), frag_32b(make_coord(_0{}, j, 2 * i + 1), mi));
                uint32_t upper = frag_32b(make_coord(_0{}, j, 2 * i), mi);
                uint32_t lower = frag_32b(make_coord(_0{}, j, 2 * i + 1), mi);
                uint32_t upper0 = lane_03 ? upper : lower;
                uint32_t lower0 = lane_03 ? lower : upper;
                upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
                lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
                frag_32b(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0;
                frag_32b(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? lower0 : upper0;
            }
        }
    }
    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, float const softcap){
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
    }
}

template <typename Engine, typename Layout>
__forceinline__ __device__ auto calculate_dtanh(Tensor<Engine, Layout> &tensor){
    Tensor out = make_fragment_like<float>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        out(i) = 1.f - (tensor(i) * tensor(i));
    }
    return out;
}
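// Usage sketch: apply_softcap computes x <- tanh(x * softcap) elementwise, so a soft cap of the
// form cap * tanh(x / cap) can be obtained by passing 1/cap here and scaling the result by cap
// afterwards (or folding cap into the softmax scale). calculate_dtanh then gives the elementwise
// derivative 1 - tanh(x)^2 from the already-tanh'ed values, as needed in the backward pass.
// `acc_s` and `cap` are hypothetical.
//
//     flash::apply_softcap(acc_s, 1.f / cap);       // acc_s now holds tanh(scores / cap)
//     Tensor dtanh = flash::calculate_dtanh(acc_s); // per-element 1 - tanh^2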
////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash