broadcast_load_epilogue_c2x.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
//
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they were passed
// in via float values. This created a potential performance hazard if scales
// were initially on the device, and caused torch.compile graph breaks when
// moving scales to the CPU.
//
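// Illustrative usage (a sketch, not part of this header): the same visitor
// covers per-channel and per-tensor scales, and only the flag changes, so one
// compiled kernel serves both granularities. The pointer name `scale_b_ptr`,
// the ThreadMap, and the Stride choice below are hypothetical placeholders.
//
//   using RowVisitor =
//       VisitorRowOrScalarBroadcast<ThreadMap, float, Stride<_0, _1, _0>>;
//
//   // Per-channel scales: one float per output column, loaded as a row vector.
//   RowVisitor::Arguments per_channel{scale_b_ptr, /*row_broadcast=*/true, {}};
//
//   // Per-tensor scale: a single float in device memory, broadcast everywhere.
//   RowVisitor::Arguments per_tensor{scale_b_ptr, /*row_broadcast=*/false, {}};
//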
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {
  // This struct has been modified to add a bool indicating whether ptr_row
  // points to a row vector to broadcast (row_broadcast == true) or to a single
  // scalar that must be broadcast to every element (row_broadcast == false).
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;
  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are loading from a scalar and broadcasting
        VecType filled_vec;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {
  // This struct has been modified to add a bool indicating whether ptr_col
  // points to a column vector to broadcast (col_broadcast == true) or to a
  // single scalar that must be broadcast to every element (col_broadcast == false).
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };
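  // Illustrative usage (a sketch; `scale_a_ptr` and the ThreadMap are
  // hypothetical placeholders): per-token scales supply one value per output
  // row and set col_broadcast = true, while a single per-tensor scale in
  // device memory sets it to false.
  //
  //   using ColVisitor =
  //       VisitorColOrScalarBroadcast<ThreadMap, float, Stride<_1, _0, _0>>;
  //   ColVisitor::Arguments per_token{scale_a_ptr, /*col_broadcast=*/true, {}};
  //   ColVisitor::Arguments per_tensor{scale_a_ptr, /*col_broadcast=*/false, {}};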

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;
  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;
    RTensor tC_rCol;
    CTensor tC_cCol;
    Params const* params_ptr;
    int m;

    // This function is modified from VisitorColBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rCol);

      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // In this case we are loading from a column vector and broadcasting
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // In this case we are loading from a scalar and broadcasting
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Generate the pred tensor
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};

} // namespace cutlass::epilogue::threadblock