/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// always reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
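// Illustrative sketch (comments only, not compiled): with this design the host
// selects the broadcast mode at runtime through a flag rather than by compiling
// a separate kernel. Assuming a hypothetical device pointer `d_scales` holding
// either one value (per-tensor) or one value per channel/token:
//
//   // per-channel/per-token: broadcast the full vector from the device pointer
//   //   args = {d_scales, /*broadcast=*/true,  {}};
//   // per-tensor: treat *d_scales as a single scalar applied to every element
//   //   args = {d_scales, /*broadcast=*/false, {}};
//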
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"

#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast

  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
  struct SharedStorage {
    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };
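
  // Example (a hedged sketch, not compiled here; `d_row_scales` and
  // `d_scalar_scale` are hypothetical device pointers):
  //   per-channel scales, one device-resident value per column:
  //     Arguments{d_row_scales, /*row_broadcast=*/true, {}};
  //   per-tensor scale, a single device-resident value broadcast everywhere:
  //     Arguments{d_scalar_scale, /*row_broadcast=*/false, {}};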

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params),
        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }

  Params params;
  Element* smem_row;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return true;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row) == Element(0));
  }

  template <int EpiTiles, class GTensor, class STensor>
  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
    CUTLASS_DEVICE
    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
      : gRow(cute::forward<GTensor>(gRow)),
        sRow(cute::forward<STensor>(sRow)),
        params(params) {}

    GTensor gRow;        // (CTA_M,CTA_N)
    STensor sRow;        // (CTA_M,CTA_N,PIPE)
    Params const& params;

    CUTLASS_DEVICE void
    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
      if (!params.row_broadcast) {
        return;
      }

      if (issue_tma_load) {
        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
        // Issue the TMA bulk copy
        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
        // Filter so we don't issue redundant copies over stride-0 modes
        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
      }
    }
  };

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));   // (CTA_M,CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                   // (CTA_M,CTA_N,PIPE)
                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;

    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
      cute::move(gRow), cute::move(sRow), params);
  }

  template <int EpiTiles, class RTensor, class STensor>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
      : tCrRow(cute::forward<RTensor>(tCrRow)),
        tCsRow(cute::forward<STensor>(tCsRow)),
        params(params) {}

    RTensor tCrRow;      // (CPY,CPY_M,CPY_N)
    STensor tCsRow;      // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
    Params const& params;

    CUTLASS_DEVICE void
    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
      if (!params.row_broadcast) {
        fill(tCrRow, *(params.ptr_row));
        return;
      }

      if (epi_m == 0) { // Assumes M-major subtile loop
        // Filter so we don't issue redundant copies over stride-0 modes
        // (only works if 0-strides are in same location, which is by construction)
        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                   // (CTA_M,CTA_N,PIPE)
                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                           // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
                      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                 // (CPY,CPY_M,CPY_N)

    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
      cute::move(tCrRow), cute::move(tCsRow), params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };
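
  // Example (a hedged sketch, not compiled here; `d_col_scales` and
  // `d_scalar_scale` are hypothetical device pointers):
  //   per-token scales, one device-resident value per row:
  //     Arguments{d_col_scales, /*col_broadcast=*/true, {}};
  //   per-tensor scale, a single device-resident value broadcast everywhere:
  //     Arguments{d_scalar_scale, /*col_broadcast=*/false, {}};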

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template<class GTensor, class RTensor>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
      : tCgCol(cute::forward<GTensor>(tCgCol)),
        tCrCol(cute::forward<RTensor>(tCrCol)),
        params(params) {}

    GTensor tCgCol;      // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;      // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_aligned(filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(                           // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);                                            // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
      cute::move(tCgCol), cute::move(tCrCol), params);
  }
};

} // namespace cutlass::epilogue::fusion