machete_prepack_launcher.cuh

#pragma once

#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"

namespace machete {

template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(B));

  using ElementB = typename PrepackedLayoutB::ElementB;
  using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;

  auto device = B.device();
  auto stream = at::cuda::getCurrentCUDAStream(device.index());
  auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());

  // Elements per storage item for B.
  auto eles_per_storage =
      (B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
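  // e.g. int32 storage holding 4-bit quantized elements:
  //   (4 bytes * 8 bits) / 4 bits = 8 elements per storage item.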

  // The torch tensor B passed in is/should be (packed_K, N); the kernel
  // expects (N, K, L) (to match cutlass, which uses (N, K, L) for B), so we
  // transpose B to (N, packed_K, L).
  auto Bt_packed = B.t();

  TORCH_CHECK(
      (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
      "B.shape[0] (in terms of unpacked elements) must be a multiple of ",
      size<1>(PPBlockShape_NK{}));
  TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
              "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));

  using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
  auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");

  // Convert the (N, packed_K, L) layout to an (N, K, L) layout. In effect we
  // want to do:
  //   blocked_product(l_Bt_packed,
  //                   make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
  //                                       Step<_1, _0, _2>{}));
  // but blocked_product does not support dynamic strides, so we implement the
  // equivalent manually:
  //   new_shape  = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
  //   new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
  //                when s1 == 1
  TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
  // clang-format off
  auto const layout_Bt = make_layout(
    transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
      // Expand the packed_K mode (idx == 1) to the unpacked K extent.
      return idx == 1 ? ele * eles_per_storage : ele;
    }),
    transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
      // Rescale the non-K strides into unpacked-element units.
      return idx != 1 ? ele * eles_per_storage : ele;
    }));
  // clang-format on
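  // e.g. assuming (hypothetically) N = 4096, packed_K = 512, L = 1 and
  // eles_per_storage = 8, with packed strides (s0, 1, s2):
  //   shape  (4096, 512, 1)  -> (4096, 4096, 1)
  //   stride (s0, 1, s2)     -> (8*s0, 1, 8*s2)
  // i.e. offsets along N and L are re-expressed in unpacked-element units
  // while the contiguous K mode keeps stride 1.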

  // Allocate the output tensor (same packed dtype/shape as B, contiguous).
  torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);

  prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
                              static_cast<ElementB*>(D.mutable_data_ptr()));

  return D;
}

template <typename ElementA, typename ElementB, typename ElementD,
          typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
          typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
  static torch::Tensor dispatch(torch::Tensor B);
};
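
// Note: dispatch() is only declared here; the concrete definitions
// (presumably generated, per-type-combination instantiations) must live in
// separate translation units, keeping this header lightweight.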

}  // namespace machete