broadcast_load_epilogue_c2x.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
//
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they were passed
// in via float values. This created a potential performance hazard if scales
// were initially on the device, and caused torch.compile graph breaks when
// moving scales to the CPU.
//
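// A minimal usage sketch (not part of the upstream file; `scale_ptr` and
// `per_channel` are hypothetical names). A host-side launcher could fill the
// Arguments so a single device pointer covers both quantization granularities;
// a stride of (0,1,0) over (M,N,L) is assumed here for a row vector:
//
//   using ScaleArgs = VisitorRowOrScalarBroadcast<
//       ThreadMap, float, Stride<_0, _1, _0>>::Arguments;
//   // scale_ptr always points to device memory and holds either N values
//   // (per-channel) or a single value (per-tensor).
//   ScaleArgs scale_args{scale_ptr, per_channel, {}};
//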
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
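  // For illustration (values assumed, not from the original): with
  // Element == float and ThreadMap::kElementsPerAccess == 4, vec_bits is
  // 4 * 32 == 128, VecType is uint_bit_t<128>, and VecLength is
  // 16 / 4 == 4 elements per vectorized load.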
  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are loading from a scalar and broadcasting
        VecType filled_vec;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }
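
    // visit() below returns one fragment of the pre-loaded row values; it is
    // indexed by column_idx because a broadcast row varies only along the
    // N (column) dimension.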
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
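//
// A minimal sketch (hypothetical names `bias_ptr`, `has_bias`): the same
// compiled epilogue can apply a real bias row or no bias at all, selected at
// runtime purely by the pointer value, assuming a (0,1,0) row-vector stride:
//
//   using BiasArgs = VisitorRowOrZeroBroadcast<
//       ThreadMap, ElementBias, Stride<_0, _1, _0>>::Arguments;
//   BiasArgs bias_args{has_bias ? bias_ptr : nullptr, {}};
//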
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  // This struct has been modified to remove null_default (because it's always 0)
  struct Arguments {
    Element const* ptr_row = nullptr;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row != nullptr) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are broadcasting 0
        VecType filled_vec;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
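//
// A minimal sketch (hypothetical names `a_scale_ptr`, `per_token_quant`):
// per-token scales (one value per output row) use col_broadcast == true, while
// a single per-tensor scale reuses the same kernel with col_broadcast == false,
// using the template's default Stride<_1,_0,_0> for a column vector:
//
//   using AScaleArgs = VisitorColOrScalarBroadcast<
//       ThreadMap, float, Stride<_1, _0, _0>>::Arguments;
//   AScaleArgs a_scale_args{a_scale_ptr, per_token_quant, {}};
//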
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;
    RTensor tC_rCol;
    CTensor tC_cCol;
    Params const* params_ptr;
    int m;

    // This function is modified from VisitorColBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rCol);

      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // In this case we are loading from a column vector and broadcasting
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // In this case we are loading from a scalar and broadcasting
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }
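
    // visit() below fills the whole fragment with a single value: a fragment
    // spans consecutive columns of one output row, and the column-broadcast
    // value depends only on the row (row_idx, iter_idx), not the column.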
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Generate the pred tensor
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};

} // namespace cutlass::epilogue::threadblock