torch_utils.hpp

#pragma once

#include <torch/all.h>

#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"

using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;
namespace cute {

namespace detail {

template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
                                                seq<I...>) {
  return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}

template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
  return make_shape(f(I)...);
}

}  // namespace detail

template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
  if constexpr (cute::is_tuple<T>::value) {
    return detail::tapply_with_idx(
        t, f, [](auto const&... a) { return cute::make_tuple(a...); },
        tuple_seq<T>{});
  } else {
    return f(t);
  }

  CUTE_GCC_UNREACHABLE;
}

// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
  return detail::make_shape_from_idx(f, make_seq<N>{});
}

}  // namespace cute
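
// Illustrative usage (a hedged sketch, not part of the original header): both
// helpers pass the element's index to the callable as a plain `int`, e.g.:
//
//   // yields cute::make_tuple(20, 31, 42): each element plus its index
//   auto t = cute::transform_with_idx(
//       cute::make_tuple(20, 30, 40),
//       [](auto const& e, auto idx) { return e + idx; });
//
//   // yields make_shape(0, 1, 2), i.e. make_shape(f(0), f(1), f(2))
//   auto s = cute::make_shape_from_idx<3>([](int idx) { return idx; });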
// Make a layout with rank `rank(Stride{})` from a tensor: the shape is the
// shape of the passed-in tensor, and the strides have type `Stride` and hold
// the strides of the passed-in tensor, with any static strides in `Stride{}`
// checked against the tensor's actual strides.
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the
// extra strides are set to 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
                                    std::string_view name = "tensor") {
  TORCH_CHECK(tensor.dim() <= rank(Stride{}));
  auto stride = cute::transform_with_idx(
      Stride{}, [&](auto const& stride_ele, auto const& idx) {
        using StrideEle = std::decay_t<decltype(stride_ele)>;

        if (idx < tensor.dim()) {
          if constexpr (cute::is_static_v<StrideEle>) {
            TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
                        name, ".stride(", idx, ") to be ", StrideEle::value);
            return StrideEle{};
          } else {
            if (tensor.size(idx) == 1) {
              // use 0 stride for dim with size 1, this is easier for
              // cute/cutlass to optimize (helps the TMA code flatten dims)
              return StrideEle{0};
            } else {
              return tensor.stride(idx);
            }
          }
        } else {
          // Extra strides are assumed to be 0 or 1
          if constexpr (cute::is_static_v<StrideEle>) {
            static_assert(StrideEle::value == 0 || StrideEle::value == 1);
          }
          return StrideEle{};
        }
      });

  auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
    if (idx < tensor.dim())
      return tensor.size(idx);
    else
      return int64_t(1);
  });

  return make_layout(shape, stride);
}
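
// Usage sketch (hypothetical shapes, assuming a contiguous row-major 2D
// tensor): the static unit stride in `StrideA` is checked against the
// tensor's actual stride at runtime via the TORCH_CHECK above.
//
//   using StrideA = cute::Stride<int64_t, cute::Int<1>>;
//   torch::Tensor A = torch::empty({128, 64});  // strides: (64, 1)
//   auto layout_A = make_cute_layout<StrideA>(A, "A");
//   // layout_A has shape (128, 64) and stride (64, _1)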
template <typename Stride>
static inline auto maybe_make_cute_layout(
    c10::optional<torch::Tensor> const& tensor,
    std::string_view name = "tensor") {
  using Layout = decltype(make_cute_layout<Stride>(*tensor));

  if (tensor) {
    return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
  } else {
    return std::optional<Layout>{};
  }
}
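
// Usage sketch (hypothetical optional input): `decltype` is an unevaluated
// context, so the `*tensor` used for type deduction above is never actually
// dereferenced when the optional is empty.
//
//   c10::optional<torch::Tensor> maybe_bias = c10::nullopt;
//   auto maybe_layout = maybe_make_cute_layout<StrideA>(maybe_bias, "bias");
//   // maybe_layout == std::nullopt here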
//
// Torch Type to Cutlass Type (equivalent_cutlass_type)
//

template <typename T>
struct equivalent_cutlass_type {
  using type = T;
};

template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;

template <>
struct equivalent_cutlass_type<c10::Half> {
  using type = cutlass::half_t;
};

template <>
struct equivalent_cutlass_type<c10::BFloat16> {
  using type = cutlass::bfloat16_t;
};
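
// For example (compile-time sketch; `std::is_same_v` comes from
// <type_traits>, pulled in transitively by the includes above):
//
//   static_assert(std::is_same_v<equivalent_cutlass_type_t<c10::Half>,
//                                cutlass::half_t>);
//   static_assert(std::is_same_v<equivalent_cutlass_type_t<float>, float>);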
//
// equivalent_scalar_t (basically the inverse of equivalent_cutlass_type)
//

// Return a `c10::CppTypeToScalarType<T>`-compatible type, i.e. get the C++
// type from c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
  using type = T;
};

template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;

template <>
struct equivalent_scalar_type<cutlass::half_t> {
  using type = c10::Half;
};

template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
  using type = c10::BFloat16;
};
// get the equivalent c10::ScalarType tag from a compile-time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
    c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
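
// Usage sketch (hypothetical `check_dtype` helper): validate a tensor's
// runtime dtype against a kernel's compile-time element type, e.g.:
//
//   template <typename Element>  // e.g. cutlass::half_t
//   void check_dtype(torch::Tensor const& t) {
//     TORCH_CHECK(t.scalar_type() == equivalent_scalar_type_v<Element>,
//                 "unexpected dtype");
//   }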