#pragma once

#include <torch/all.h>

#include <type_traits>

// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"

namespace machete {

using namespace cute;

struct IlvBlkLayoutAuto {};

// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the
// same shape as the prepacked block, all the data for a given thread is
// contiguous in memory. This allows us to use wider shared memory loads when
// loading B from shared memory. The values within a thread are also
// potentially interleaved in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel (this is also why the other element types are
// required along with the kernel schedule).
// clang-format off
template <typename ElementA_, typename ElementB_, typename ElementD_,
          typename AccumulatorT, class LayoutB, class KernelSchedule,
          typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementD = ElementD_;
  using ElementAccumulator =
      AccumulatorT;  // Element type for internal accumulation
  using ElementMma = MmaType;

  // Only use interleaved layouts for subbyte weights; the prmt instruction
  // makes non-interleaved layouts for 8bit+ weights efficient enough that we
  // don't need interleaved layouts.
  using IlvdBlkLayout = std::conditional_t<
      std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
      std::conditional_t<sizeof_bits_v<ElementB> <= 4,
                         decltype(get_interleaved_blk_layout<
                                  ElementB, sizeof_bits_v<MmaType>, 32>()),
                         void>,
      IlvBlkLayout_>;

  // TODO (LucasWilkinson): compare the performance for other sizes
  // Prepacked block shape, smallest layout atom for loading into registers
  //   (can contain multiple wgmma instructions worth of data in one block).
  // We ideally want this to be configured such that a thread can perform
  // 128bit loads, i.e. the amount of data associated with each thread within a
  // prepacked block is a multiple of 128bits. When using a cooperative
  // schedule we have 256 threads working on a single block at a time, which
  // means each thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits
  // of data; for a 4bit type this would be 128bits.
  using PPBlockShape_NK = Shape<_128, _64>;
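  // Worked example of the arithmetic above (illustrative): a prepacked block
  // covers 128 * 64 = 8192 elements of B. With 4bit weights each of the 256
  // cooperating threads therefore owns 4 * 8192 / 256 = 128 bits, i.e. exactly
  // one 128bit shared-memory load per thread per block; with 8bit weights it
  // owns 256 bits, i.e. two 128bit loads.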
  // Create the shape of the tile anticipated to be used by the GEMM kernel;
  //  when the kernel executes we will compute `Ct = Bt * At` since the
  //  quantized weights (B) must be the lhs operand so that they flow through
  //  registers.
  // The _128 here doesn't actually impact the shape of the stored tile
  //  directly but may impact the op selected by rs_op_selector.
  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
                                            size<1>(PPBlockShape_NK{})));

  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<LayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions, so we use 2 atoms along the M dim (one for each warpgroup).
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule,
                      cutlass::gemm::KernelTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));

  // Prepacked block, (athrid, val) -> (N,K)
  //   i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
  }

  // Prepacked block, (N,K) -> (athrid, val)
  //   i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  //   i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    // Return interleaved layout
    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  //   i.e. ((ThrV,(ThrN,ThrK)),(IlvdFrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
    auto layout_no_interleave =
        make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});

    if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
      return layout_no_interleave;
    } else {
      // Interleave by transforming FrgV into interleaved blocks where each
      // block has the layout IlvdBlkLayout. For example, if IlvdBlkLayout is
      // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4);
      //   if FrgV is {A, B, C, D, E, F, G, H}
      //   then ((IlvBlk), FrgV) is {A, C, B, D, E, G, F, H}
      auto frgV = get<1, 0>(layout_no_interleave);
      auto ilvdBlk = IlvdBlkLayout{};
      static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
      auto ilvd_FrgV = make_layout(
          make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
          make_stride(stride(ilvdBlk), size(ilvdBlk)));

      // Return interleaved layout
      return make_layout(
          get<0>(layout_no_interleave),
          make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
    }
  }

  // Prepacked block, (N,K) -> (storage_offset)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
    // do (N,K) -> (athrid, val) -> (storage_idx)
    return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
  }

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_TV_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }
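  // Worked example for TVbNbKL_to_offset above (illustrative; the shape is an
  // assumption): for a prepacked B tensor of logical shape (N, K, L) =
  // (256, 128, 1) with PPBlockShape_NK = (128, 64), blocks_shape is (2, 2, 1).
  // Each block occupies size(block_layout) = 128 * 64 = 8192 storage slots, so
  // the blocks start at offsets 0, 8192, 16384 and 24576 (column-major over
  // (BlocksN, BlocksK)).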
  // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_ilvd_NK_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L))
    //   => ((BlockN, BlockK), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
    auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
                          make_layout(size<1>(PPBlockShape_NK{})));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
    auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
    return tiled_A.compose(ppblock_TV_to_NK(), _);
  }

  // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
    auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
    return blocked_product(ppblock_NK_to_TV(),
                           make_layout(shape<1>(TVbNbK_to_NKL_layout)));
  }
};

};  // namespace machete
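// Usage sketch (illustrative only; the element types, layout tag and kernel
// schedule below are assumptions chosen for the example, not requirements of
// this header):
//
//   using PrepackedLayoutB = machete::PrepackedLayoutBTemplate<
//       cutlass::half_t,               // ElementA_
//       cutlass::uint4b_t,             // ElementB_ (4bit weights)
//       cutlass::half_t,               // ElementD_
//       float,                         // AccumulatorT
//       cutlass::layout::ColumnMajor,  // LayoutB
//       cutlass::gemm::KernelTmaWarpSpecializedCooperative>;
//
//   // ((athrid, val), (BlocksN, BlocksK), L) -> storage_idx for a prepacked
//   // B tensor of logical shape (N, K, L)
//   auto offset_layout =
//       PrepackedLayoutB::TVbNbKL_to_offset(cute::make_shape(N, K, 1));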