/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
//
// It has been modified to support either row/column or scalar broadcasting,
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
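// Illustrative host-side sketch (the tensor `scale` and its accessors below are
// hypothetical and not part of this header): the same compiled kernel is used in
// both quantization modes; only the epilogue arguments differ.
//
//   bool per_tensor    = (scale.numel() == 1);  // one scale vs. one per row/column
//   args.ptr_row       = scale.data_ptr();      // always a device pointer
//   args.row_broadcast = !per_tensor;           // false => broadcast the single scalar
//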
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_0,_1,_0>>) ||  // row vector broadcast, e.g. per-col alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast

  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
  struct SharedStorage {
    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };
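  // Minimal sketch of how these arguments might be populated (`scale_b_ptr` is a
  // hypothetical device pointer holding either N per-channel scales or a single
  // per-tensor scale):
  //
  //   Arguments per_channel{scale_b_ptr, /*row_broadcast=*/true,  {}};  // N scales, one per output column
  //   Arguments per_tensor {scale_b_ptr, /*row_broadcast=*/false, {}};  // 1 scale, broadcast to all columns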
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params),
        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }

  Params params;
  Element* smem_row;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return true;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row) == Element(0));
  }

  template <class GTensor, class STensor, int EpiTiles>
  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
    CUTLASS_DEVICE
    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
      : gRow(cute::forward<GTensor>(gRow)),
        sRow(cute::forward<STensor>(sRow)),
        params(params) {}

    GTensor gRow;                                                                 // (CTA_M,CTA_N)
    STensor sRow;                                                                 // (CTA_M,CTA_N,PIPE)
    Params const& params;

    CUTLASS_DEVICE void
    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
      if (!params.row_broadcast) {
        return;
      }

      if (issue_tma_load) {
        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
        // Issue the TMA bulk copy
        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
        // Filter so we don't issue redundant copies over stride-0 modes
        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
      }
    }
  };

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));        // (CTA_M,CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                        // (CTA_M,CTA_N,PIPE)
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));

    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
    return ProducerLoadCallbacks<decltype(gRow), decltype(sRow), EpiTiles>(
      cute::move(gRow), cute::move(sRow), params);
  }

  template <class RTensor, class STensor, int EpiTiles>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
      : tCrRow(cute::forward<RTensor>(tCrRow)),
        tCsRow(cute::forward<STensor>(tCsRow)),
        params(params) {}

    RTensor tCrRow;                                                               // (CPY,CPY_M,CPY_N)
    STensor tCsRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
    Params const& params;

    CUTLASS_DEVICE void
    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
      if (!params.row_broadcast) {
        fill(tCrRow, *(params.ptr_row));
        return;
      }

      if (epi_m == 0) { // Assumes M-major subtile loop
        // Filter so we don't issue redundant copies over stride-0 modes
        // (only works if 0-strides are in same location, which is by construction)
        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                        // (CTA_M,CTA_N,PIPE)
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                                // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
                      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                      // (CPY,CPY_M,CPY_N)

    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
    return ConsumerStoreCallbacks<decltype(tCrRow), decltype(tCsRow), EpiTiles>(
      cute::move(tCrRow), cute::move(tCsRow), params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0,_0>>) ||  // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };
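  // Minimal sketch of how these arguments might be populated (`scale_a_ptr` is a
  // hypothetical device pointer holding either M per-token scales or a single
  // per-tensor scale):
  //
  //   Arguments per_token {scale_a_ptr, /*col_broadcast=*/true,  {}};  // M scales, one per output row
  //   Arguments per_tensor{scale_a_ptr, /*col_broadcast=*/false, {}};  // 1 scale, broadcast to all rows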
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template <class GTensor, class RTensor>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
      : tCgCol(cute::forward<GTensor>(tCgCol)),
        tCrCol(cute::forward<RTensor>(tCrCol)),
        params(params) {}

    GTensor tCgCol;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_aligned(filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(                        // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);                                         // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
      cute::move(tCgCol), cute::move(tCrCol), params);
  }
};

} // namespace cutlass::epilogue::fusion
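// ---------------------------------------------------------------------------
// Composition sketch (illustrative only; `TileShape` and the `Stages` count are
// assumptions, not definitions from this header). These nodes are intended as
// drop-in leaves for a CUTLASS epilogue visitor tree, e.g. scaling the
// accumulator by per-token (column) and per-channel (row) scales:
//
//   using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
//                    0, TileShape, float>;        // per-token or scalar
//   using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
//                    Stages, TileShape, float>;   // per-channel or scalar
//   using Accum  = cutlass::epilogue::fusion::Sm90AccFetch;
//   using Mul    = cutlass::epilogue::fusion::Sm90Compute<
//                    cutlass::multiplies, float, float,
//                    cutlass::FloatRoundStyle::round_to_nearest>;
//   using EVT    = cutlass::epilogue::fusion::Sm90EVT<Mul, ScaleA,
//                    cutlass::epilogue::fusion::Sm90EVT<Mul, ScaleB, Accum>>;
// ---------------------------------------------------------------------------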