machete_prepack_kernel.cuh 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #pragma once
  2. #include "machete_mm_kernel.cuh"
  3. #include "cutlass_extensions/cute_utils.cuh"
  4. #include "cutlass_extensions/torch_utils.hpp"
  5. namespace machete {
  6. template <typename TileShapeNKL, typename ElementB, typename BInTensor,
  7. typename BTiledOutTensor>
  8. static __global__ void prepack_B_kernel(BInTensor B_in,
  9. BTiledOutTensor B_tiled_out) {
  10. auto tB_in = local_tile(B_in, TileShapeNKL{},
  11. make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
  12. auto tB_out = B_tiled_out(make_coord(_, _),
  13. make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
  14. auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
  15. Layout<Shape<_4, _32>, Stride<_32, _1>>{},
  16. Layout<Shape<_1, _2>>{});
  17. auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
  18. Tensor thr_tile_S = thr_copy.partition_S(tB_in);
  19. Tensor thr_tile_D = thr_copy.partition_D(tB_out);
  20. // Construct a register-backed Tensor with the same shape as each thread's
  21. // partition
  22. auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
  23. // Copy from GMEM to RMEM and from RMEM to GMEM
  24. copy(tiled_copy, thr_tile_S, fragment);
  25. copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
  26. }
  27. template <typename PrepackedLayoutB, typename InLayout>
  28. static void prepack_B(cudaStream_t stream,
  29. typename PrepackedLayoutB::ElementB const* B_in_ptr,
  30. InLayout B_layout,
  31. typename PrepackedLayoutB::ElementB* B_out_ptr) {
  32. using TileShapeNKL =
  33. decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
  34. auto ilvd_NKbNbKL_to_offset =
  35. PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
  36. TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
  37. TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
  38. TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
  39. auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
  40. auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
  41. auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
  42. auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
  43. auto B_tiled_out =
  44. make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
  45. prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
  46. <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
  47. }
  48. }; // namespace machete