machete_interleaving_utils.cuh 1.3 KB

1234567891011121314151617181920212223242526272829303132333435
  1. #pragma once
  2. #include "cutlass/cutlass.h"
  3. #include "cute/layout.hpp"
  4. namespace machete {
  5. using namespace cute;
  6. // get an interleaved block layout where each element consecutive element has a
  7. // stride of bit_stride and the block width is blk_bit_width,
  8. // examples:
  9. // size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
  10. // size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
  11. // size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
  12. // size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
  13. template <typename T, int bit_stride, int blk_bit_width>
  14. CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
  15. static_assert(blk_bit_width % bit_stride == 0);
  16. static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);
  17. constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;
  18. if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
  19. // identity layout
  20. return Layout<Shape<Int<elems_per_blk>>>{};
  21. } else {
  22. constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
  23. constexpr auto num_strides = elems_per_blk / elems_per_stride;
  24. return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
  25. Stride<Int<elems_per_stride>, Int<1>>>{};
  26. }
  27. }
  28. }; // namespace machete