broadcast_load_epilogue_c3x.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0.
// It has been modified to support either row/column or scalar broadcasting,
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// always reside on the device, which avoids an issue with a previous
// implementation where scalars had to live on the CPU because they were
// passed in as plain float values. That created a potential performance
// hazard when scales started out on the device, and it caused torch.compile
// graph breaks when moving scales to the CPU.
//
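// Minimal composition sketch (an assumption for illustration, not part of
// this file): these visitors are intended to be plugged into CUTLASS's
// epilogue visitor tree (EVT) nodes, e.g. to scale the accumulator by a
// per-token (column) vector and a per-channel (row) vector, each of which may
// degenerate to a device-resident scalar. Names such as TileShape_MNK and
// ElementD are placeholders.
//
//   using Accum  = cutlass::epilogue::fusion::Sm90AccFetch;
//   using ScaleA = Sm90ColOrScalarBroadcast<0, TileShape_MNK, float>;  // per-token or scalar
//   using ScaleB = Sm90RowOrScalarBroadcast<0, TileShape_MNK, float>;  // per-channel or scalar
//   using Mul0   = cutlass::epilogue::fusion::Sm90Compute<
//       cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;
//   using Mul1   = cutlass::epilogue::fusion::Sm90Compute<
//       cutlass::multiplies, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>;
//   using EVT    = cutlass::epilogue::fusion::Sm90EVT<Mul1, ScaleA,
//                      cutlass::epilogue::fusion::Sm90EVT<Mul0, ScaleB, Accum>>;
//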
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };
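  // Illustrative host-side sketch (not part of the upstream interface): the
  // same Arguments type covers both quantization granularities. The pointer
  // names below are hypothetical; both must point to device memory.
  //
  //   // Per-channel scales: one Element per output column (length N),
  //   // broadcast down the M dimension.
  //   Arguments per_channel_args{d_channel_scales, /*row_broadcast=*/true, {}};
  //
  //   // Per-tensor scale: a single Element, still resident on the device.
  //   Arguments per_tensor_args{d_scalar_scale, /*row_broadcast=*/false, {}};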
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }
  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params)
      , smem(const_cast<Element*>(shared_storage.smem.data())) { }

  Params params;
  Element *smem = nullptr;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row) == Element(0));
  }

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }
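  // Note: there is no producer (TMA) load path for this visitor. When
  // row_broadcast is set, the row vector is staged from gmem to smem by the
  // epilogue consumer threads themselves in ConsumerStoreCallbacks::begin();
  // when it is not set, the scalar is simply broadcast into registers.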
  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
      : tGS_gRow(tGS_gRow_)
      , tGS_sRow(tGS_sRow_)
      , tGS_cRow(tGS_cRow_)
      , tiled_G2S(tiled_g2s_)
      , tSR_sRow(tSR_sRow_)
      , tSR_rRow(tSR_rRow_)
      , tCcRow(tCcRow_)
      , residue_tCcRow(residue_tCcRow_)
      , thr_num(thr_num_)
      , params(params_) {}

    GS_GTensor tGS_gRow;        // (CPY,CPY_M,CPY_N)
    GS_STensor tGS_sRow;        // (CPY,CPY_M,CPY_N)
    GS_CTensor tGS_cRow;        // (CPY,CPY_M,CPY_N)
    Tiled_G2S tiled_G2S;

    SR_STensor tSR_sRow;        // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    SR_RTensor tSR_rRow;        // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    CTensor tCcRow;             // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    ThrResidue residue_tCcRow;  // (m, n)
    ThrNum thr_num;
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.row_broadcast) {
        fill(tSR_rRow, *(params.ptr_row));
        return;
      }

      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));

      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
          continue; // OOB of SMEM
        }
        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
          tGS_sRow_flt(i) = tGS_gRow_flt(i);
        }
        else {
          tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
        }
      }
      synchronize();
    }

    CUTLASS_DEVICE void
    begin_loop(int epi_m, int epi_n) {
      if (epi_m == 0) { // Assumes M-major subtile loop
        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
        copy(tSR_sRow_flt, tSR_rRow_flt);
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };
  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    using ThreadCount = decltype(size(args.tiled_copy));

    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem),
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
    //// G2S: Gmem to Smem
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                                     Layout< Shape<_1, ThreadCount>,
                                            Stride<_0,          _1>>{},
                                     Layout<_1>{});
    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
    Tensor tGS_sRow = thr_g2s.partition_D(sRow);

    //// G2S: Coord
    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
    Tensor tGS_cRow = thr_g2s.partition_S(cRow);

    //// S2R: Smem to Reg
    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));  // (CPY,CPY_M,CPY_N)

    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
      tGS_gRow,
      tGS_sRow,
      tGS_cRow, tiled_g2s,
      tSR_sRow,
      tSR_rRow,
      args.tCcD,
      args.residue_cD,
      ThreadCount{},
      params);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };
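  // Illustrative host-side sketch (not part of the upstream interface),
  // mirroring the row case above. The pointer names are hypothetical; both
  // must point to device memory.
  //
  //   // Per-token scales: one Element per output row (length M),
  //   // broadcast across the N dimension.
  //   Arguments per_token_args{d_token_scales, /*col_broadcast=*/true, {}};
  //
  //   // Per-tensor scale: a single Element on the device.
  //   Arguments per_tensor_args{d_scalar_scale, /*col_broadcast=*/false, {}};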
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }
  template<class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
        GTensor&& tCgCol,
        RTensor&& tCrCol,
        CTensor&& tCcCol,
        ProblemShape problem_shape,
        Params const& params
    ):
      tCgCol(cute::forward<GTensor>(tCgCol)),
      tCrCol(cute::forward<RTensor>(tCrCol)),
      tCcCol(cute::forward<CTensor>(tCcCol)),
      m(get<0>(problem_shape)),
      params(params) {}

    GTensor tCgCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;
    CTensor tCcCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;
    int m;

    CUTLASS_DEVICE void
    begin() {
      Tensor pred = make_tensor<bool>(shape(tCgCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tCcCol(i)) < m;
      }

      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_if(pred, filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }
  };
  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);                   // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    // Generate an identity tensor matching the shape of the global tensor and
    // partition it the same way; this will be used to generate the predicate
    // tensor for loading.
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

    return ConsumerStoreCallbacks(
      cute::move(tCgCol),
      cute::move(tCrCol),
      cute::move(tCcCol),
      args.problem_shape_mnkl,
      params
    );
  }
};

} // namespace cutlass::epilogue::fusion