1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- #pragma once
- #include "machete_prepack_kernel.cuh"
- #include "cutlass_extensions/torch_utils.hpp"
- namespace machete {
- template <typename PrepackedLayoutB>
- torch::Tensor prepack_impl(torch::Tensor const B) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
- using ElementB = typename PrepackedLayoutB::ElementB;
- using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
- auto device = B.device();
- auto stream = at::cuda::getCurrentCUDAStream(device.index());
- auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
- // elements per storage item for B
- auto eles_per_storage =
- (B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
- // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
- // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
- auto Bt_packed = B.t();
- TORCH_CHECK(
- (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
- "B.shape[0] (in terms of unpacked elements) must be a multiple of ",
- size<1>(PPBlockShape_NK{}));
- TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
- "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
- using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
- auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
- // convert (N,packed_K,L) layout to (N,K,L) layout
- // in effect we want to do: blocked_product(layout_Bt_packed,
- // make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
- // Step<_1, _0, _2>{}));
- // but blocked_product does not support dynamic strides so we implement the
- // equivalent manually,
- // new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
- // new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
- // when s1 == 1
- TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
- // clang-format off
- auto const layout_Bt = make_layout(
- transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
- return idx == 1 ? ele * eles_per_storage : ele;
- }),
- transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
- return idx != 1 ? ele * eles_per_storage : ele;
- }));
- // clang-format on
- // Allocate output
- torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
- prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
- static_cast<ElementB*>(D.mutable_data_ptr()));
- return D;
- };
- template <typename ElementA, typename ElementB, typename ElementD,
- typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
- typename ZeroT = cutlass::half_t>
- struct PrepackBDispatcher {
- static torch::Tensor dispatch(torch::Tensor B);
- };
- }; // namespace machete
|