rotary.h
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include "utils.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
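
// Apply rotary position embedding with the interleaved (GPT-J style) layout: each consecutive
// pair of elements (x[2*i], x[2*i+1]) along the head dimension is rotated by its (cos, sin)
// value. S is loaded into registers, rotated in fp32, converted back to the original element
// type, and written to D. Rows outside [min_MN, max_MN) are skipped; columns at or past
// rotary_dim are copied through unrotated; when !Is_even_K, columns past dim are cleared if
// Clear_OOB_K is set. A typical call passes per-thread partitions with matching layouts
// (tensor names below are hypothetical):
//   flash::copy_rotary_interleaved<Is_even_K>(tRgS, tRgD, tRgCos, tRgSin, tRcS,
//                                             max_MN, min_MN, dim, rotary_dim);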
template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
                                                        Tensor<Engine1, Layout1> &D,
                                                        Tensor<Engine2, Layout2> const &Cos,
                                                        Tensor<Engine2, Layout2> const &Sin,
                                                        Tensor<Engine3, Layout3> const &identity_MN,
                                                        const int max_MN, const int min_MN,
                                                        const int dim, const int rotary_dim) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));      // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));      // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));      // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));    // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));    // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        cute::copy(Cos(_, m, k), rCos(_, m, k));
                        cute::copy(Sin(_, m, k), rSin(_, m, k));
                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
                        // Rotate each (even, odd) pair like a complex multiply:
                        // (x[2*i] + i * x[2*i+1]) * (cos + i * sin).
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
                            S_fp32(2 * i) = real;
                            S_fp32(2 * i + 1) = imag;
                        }
                        // Copy into a fresh fragment first, otherwise convert_type doesn't work here.
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
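
// Apply rotary position embedding with the contiguous (GPT-NeoX style) layout: the rotary slice
// of the head dimension is split into two halves, and element i of the first half is rotated
// together with element i of the second half. Each thread gathers its partner values through a
// view of S offset by +/- rotary_dim / 2 (gS_other), and the second half re-reads the cos/sin
// values of its first-half partner. Since is_left is decided once per k-tile, this implicitly
// assumes rotary_dim / 2 is a multiple of the per-thread tile width along K. Predication on
// min_MN / max_MN, dim, rotary_dim and Clear_OOB_K matches copy_rotary_interleaved above.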
template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
                                                       Tensor<Engine1, Layout1> &D,
                                                       Tensor<Engine2, Layout2> const &Cos,
                                                       Tensor<Engine2, Layout2> const &Sin,
                                                       Tensor<Engine3, Layout3> const &identity_MN,
                                                       const int max_MN, const int min_MN,
                                                       const int dim, const int rotary_dim) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));      // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));      // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));      // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));    // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));    // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));    // MMA
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
                        // Gather the partner elements from the other half of the rotary slice.
                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
                        cute::copy(gS_other, rS_other);
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
                        // The right half reuses the cos/sin values of its left-half partner.
                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
                        cute::copy(gCos, rCos(_, m, k));
                        cute::copy(gSin, rSin(_, m, k));
                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
                        Tensor S_other_fp32 = convert_type<float>(rS_other);
                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
                        // Left half:  out = x * cos - x_partner * sin
                        // Right half: out = x * cos + x_partner * sin
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS); ++i) {
                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
                        }
                        // Copy into a fresh fragment first, otherwise convert_type doesn't work here.
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash