123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- #pragma once
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <torch/all.h>
- // clang-format off
- // The cutlass include order matters (annoyingly)
- #include "cutlass/cutlass.h"
- #include "cute/tensor.hpp"
- #include "cutlass/tensor_ref.h"
- #include "cutlass/epilogue/collective/default_epilogue.hpp"
- #include "cutlass/epilogue/thread/linear_combination.h"
- #include "cutlass/gemm/dispatch_policy.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
- #include "cutlass/epilogue/collective/collective_builder.hpp"
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
- // clang-format on
- #include "cutlass_extensions/cute_utils.cuh"
- #include "machete_collective_builder.cuh"
- #include "machete_interleaving_utils.cuh"
namespace machete {

using namespace cute;

// Sentinel type: when passed as `IlvBlkLayout_`, PrepackedLayoutBTemplate
// chooses the interleaved block layout automatically from the element types
// (see `IlvdBlkLayout` below) instead of using an explicitly supplied one.
struct IlvBlkLayoutAuto {};

// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the
// same shape as the prepacked block, all the data for a given thread is
// contiguous in memory. This allows us to use wider shared memory loads when
// loading B from shared memory. The values within a thread are also
// potentially interleaved in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementD_,
          typename AccumulatorT, class LayoutB, class KernelSchedule,
          typename IlvBlkLayout_ = IlvBlkLayoutAuto>
struct PrepackedLayoutBTemplate {
  // B is upconverted to A's type before the MMA, so A's element type is the
  // type the MMA instruction actually operates on.
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementD = ElementD_;
  using ElementAccumulator =
      AccumulatorT;  // Element type for internal accumulation
  using ElementMma = MmaType;

  // Interleaved block layout applied to each thread's fragment of B, or `void`
  // for no interleaving. Only use interleaved layouts for subbyte weights;
  // prmt instructions make non-interleaved layouts for 8bit+ weights efficient
  // enough that we don't need interleaved layouts for them.
  // NOTE(review): `get_interleaved_blk_layout` is a project helper (see
  // machete_interleaving_utils.cuh) — its exact layout is not visible here.
  using IlvdBlkLayout = std::conditional_t<
      std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
      std::conditional_t<sizeof_bits_v<ElementB> <= 4,
                         decltype(get_interleaved_blk_layout<
                                  ElementB, sizeof_bits_v<ElementA>, 32>()),
                         void>,
      IlvBlkLayout_>;

  // TODO (LucasWilkinson): compare the performance for other sizes
  // Prepacked block shape (N, K): the smallest layout atom for loading into
  // registers (can contain multiple wgmma instructions worth of data in one
  // block). We ideally want this configured such that a thread can perform
  // 128bit loads, i.e. the amount of data associated with each thread within
  // a prepacked block is a multiple of 128bits. When using a cooperative
  // schedule we have 256 threads working on a single block at a time, meaning
  // each thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of
  // data; for a 4bit type this is 128bits.
  using PPBlockShape_NK = Shape<_128, _64>;

  // Shape of the tile anticipated to be used by the GEMM kernel. When the
  // kernel executes we will compute `Ct = Bt * At`, since the quantized
  // weights (B) must be the lhs operand so they flow through registers.
  // The _128 here doesn't actually impact the shape of the stored tile
  // directly, but may impact the op selected by rs_op_selector.
  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
                                            size<1>(PPBlockShape_NK{})));

  // GMMA major-ness of B derived from the requested layout tag.
  // NOTE(review): `gmma_rs_tag_to_major_B` is a project helper not visible in
  // this file — mapping assumed correct; verify against its definition.
  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<LayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions, so we use 2 atoms along the M dim (one for each warpgroup).
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule,
                      KernelTmaWarpSpecializedCooperativeMixedInput>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  // The TiledMMA this layout is prepacked for; must match the TiledMMA the
  // kernel ultimately uses (see the contract in the struct comment above).
  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));

  // Prepacked block, (athrid, val) -> (N,K)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
  // (B is partitioned as the A operand of the TiledMMA because it is the lhs
  // in `Ct = Bt * At` — see the GemmTileShape comment above.)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
  }

  // Prepacked block, (N,K) -> (athrid, val)
  // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
  // Non-interleaved: values are packed so each thread's data is contiguous
  // (Step<_1, _0> orders the val mode fastest).
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset), with the per-thread
  // fragment additionally interleaved by IlvdBlkLayout (if any).
  // i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
    auto layout_no_interleave =
        make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});

    if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
      return layout_no_interleave;
    } else {
      // Interleave by transforming FrgV into interleaved blocks where each
      // block has the layout IlvdBlkLayout. For example, if IlvdBlkLayout is
      // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4);
      // if FrgV is {A, B, C, D, E, F, G, H}
      // then ((IlvBlk), FrgB) is {A, C, B, D, E, G, F, H}
      auto frgV = get<1, 0>(layout_no_interleave);
      auto ilvdBlk = IlvdBlkLayout{};
      static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
      // Shape: (ilvd-block shape, number of ilvd blocks per fragment);
      // stride: within a block use the block's strides, across blocks step by
      // a whole block.
      auto ilvd_FrgV = make_layout(
          make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
          make_stride(stride(ilvdBlk), size(ilvdBlk)));
      // Return interleaved layout
      return make_layout(
          get<0>(layout_no_interleave),
          make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
    }
  }

  // Prepacked block, (N,K) -> (storage_offset), interleaved
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
    // do (N,K) -> (athrid, val) -> (storage_idx)
    return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
  }

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  // `shape_mkl` is the full (N, K, L) problem shape (name kept for interface
  // compatibility); N and K must be divisible by PPBlockShape_NK.
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_TV_to_offset();

    // (BlocksN, BlocksK, L): number of prepacked blocks along each dim
    // (L is divided by _1, i.e. left unchanged).
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    // Blocks are laid out contiguously after one another, col-major over
    // (BlocksN, BlocksK, L), each occupying size(block_layout) elements.
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
  // Same as TVbNbKL_to_offset but indexed by in-block (N,K) coordinates and
  // using the interleaved per-block layout.
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_ilvd_NK_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
    auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
                          make_layout(size<1>(PPBlockShape_NK{})));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
    auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
    // Remap the per-block (BlockN, BlockK) mode through the TiledMMA's
    // (athrid, val) -> (N,K) layout, leaving the block-index mode untouched.
    return tiled_A.compose(ppblock_TV_to_NK(), _);
  }

  // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
    auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
    // Tile the per-block (N,K) -> (athrid, val) map across all blocks.
    return blocked_product(ppblock_NK_to_TV(),
                           make_layout(shape<1>(TVbNbK_to_NKL_layout)));
  }
};

};  // namespace machete
|