softmax.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cmath>
  6. #include <cute/tensor.hpp>
  7. #include <cutlass/numeric_types.h>
  8. #include "utils.h"
  9. #include "cutlass/fast_math.h"
  10. namespace flash {
  11. using namespace cute;
  12. ////////////////////////////////////////////////////////////////////////////////////////////////////
  13. template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  14. __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
  15. static_assert(Layout0::rank == 2, "Only support 2D Tensor");
  16. static_assert(Layout1::rank == 1, "Only support 1D Tensor");
  17. CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
  18. #pragma unroll
  19. for (int mi = 0; mi < size<0>(tensor); mi++) {
  20. summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
  21. #pragma unroll
  22. for (int ni = 1; ni < size<1>(tensor); ni++) {
  23. summary(mi) = op(summary(mi), tensor(mi, ni));
  24. }
  25. }
  26. }
  27. template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  28. __device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
  29. CUTE_STATIC_ASSERT_V(size(dst) == size(src));
  30. #pragma unroll
  31. for (int i = 0; i < size(dst); i++){
  32. dst(i) = Allreduce<4>::run(src(i), op);
  33. }
  34. }
  35. template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  36. __device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
  37. thread_reduce_<zero_init>(tensor, summary, op);
  38. quad_allreduce_(summary, summary, op);
  39. }
  40. template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  41. __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
  42. MaxOp<float> max_op;
  43. reduce_<zero_init>(tensor, max, max_op);
  44. }
  45. template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  46. __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
  47. SumOp<float> sum_op;
  48. thread_reduce_<zero_init>(tensor, sum, sum_op);
  49. if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
  50. }
  51. __forceinline__ __device__ __half2 half_exp(__half2 x) {
  52. uint32_t tmp_out, tmp_in;
  53. tmp_in = reinterpret_cast<uint32_t&>(x);
  54. asm ("ex2.approx.f16x2 %0, %1;\n"
  55. : "=r"(tmp_out)
  56. : "r"(tmp_in));
  57. __half2 out = reinterpret_cast<__half2&>(tmp_out);
  58. return out;
  59. }
  60. // Apply the exp to all the elements.
  61. template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  62. __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
  63. static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
  64. #pragma unroll
  65. for (int mi = 0; mi < size<0>(tensor); ++mi) {
  66. MaxOp<float> max_op;
  67. max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
  68. #pragma unroll
  69. for (int ni = 1; ni < size<1>(tensor); ni++) {
  70. max(mi) = max_op(max(mi), tensor(mi, ni));
  71. }
  72. max(mi) = Allreduce<4>::run(max(mi), max_op);
  73. // If max is -inf, then all elements must have been -inf (possibly due to masking).
  74. // We don't want (-inf - (-inf)) since that would give NaN.
  75. const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
  76. sum(mi) = 0;
  77. #pragma unroll
  78. for (int ni = 0; ni < size<1>(tensor); ++ni) {
  79. // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
  80. // max * log_2(e)) This allows the compiler to use the ffma
  81. // instruction instead of fadd and fmul separately.
  82. tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
  83. sum(mi) += tensor(mi, ni);
  84. }
  85. }
  86. }
  87. // Apply the exp to all the elements.
  88. template <bool Scale_max=true, bool Check_inf=true, bool Use_max_offset=false,
  89. typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  90. __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
  91. constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
  92. static_assert(Layout0::rank == 2, "Only support 2D Tensor");
  93. static_assert(Layout1::rank == 1, "Only support 1D Tensor");
  94. CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
  95. #pragma unroll
  96. for (int mi = 0; mi < size<0>(tensor); ++mi) {
  97. // If max is -inf, then all elements must have been -inf (possibly due to masking).
  98. // We don't want (-inf - (-inf)) since that would give NaN.
  99. // If we don't have float around M_LOG2E the multiplication is done in fp64.
  100. const float max_scaled = Check_inf
  101. ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset)
  102. : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset;
  103. #pragma unroll
  104. for (int ni = 0; ni < size<1>(tensor); ++ni) {
  105. // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
  106. // max * log_2(e)) This allows the compiler to use the ffma
  107. // instruction instead of fadd and fmul separately.
  108. tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
  109. }
  110. }
  111. }
  112. ////////////////////////////////////////////////////////////////////////////////////////////////////
  113. template <int kNRows, bool Use_max_offset_ = false>
  114. struct Softmax {
  115. constexpr static bool Use_max_offset = Use_max_offset_;
  116. // constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
  117. // constexpr static float max_offset_E = max_offset * float(M_LN2);
  118. using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
  119. TensorT row_max, row_sum;
  120. const float softmax_scale_log2;
  121. CUTLASS_DEVICE Softmax(float scale_ = 1.f) : softmax_scale_log2(scale_) {};
  122. template<bool Is_first, bool Check_inf=false, typename Tensor0>
  123. __forceinline__ __device__ TensorT max(Tensor0 &acc_s) {
  124. // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
  125. Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
  126. static_assert(decltype(size<0>(scores))::value == kNRows);
  127. TensorT scores_scale;
  128. if constexpr (Is_first) {
  129. flash::template reduce_max</*zero_init=*/true>(scores, row_max);
  130. cute::fill(scores_scale, 1.f);
  131. } else {
  132. Tensor scores_max_prev = make_fragment_like(row_max);
  133. cute::copy(row_max, scores_max_prev);
  134. flash::template reduce_max</*zero_init=*/false>(scores, row_max);
  135. #pragma unroll
  136. for (int mi = 0; mi < size(row_max); ++mi) {
  137. float scores_max_cur = !Check_inf
  138. ? row_max(mi)
  139. : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
  140. scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
  141. row_sum(mi) *= scores_scale(mi);
  142. }
  143. }
  144. return scores_scale;
  145. };
  146. template<bool Is_first, bool Check_inf=false, typename Tensor0>
  147. __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s) {
  148. // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
  149. Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
  150. static_assert(decltype(size<0>(scores))::value == kNRows);
  151. TensorT scores_scale;
  152. if constexpr (Is_first) {
  153. flash::template reduce_max</*zero_init=*/true>(scores, row_max);
  154. flash::template scale_apply_exp2</*Scale_max=*/true, /*Check_inf=*/true, Use_max_offset>(scores, row_max, softmax_scale_log2);
  155. flash::reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
  156. cute::fill(scores_scale, 1.f);
  157. // if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); }
  158. } else {
  159. // Tensor scores_max_prev = make_fragment_like(row_max);
  160. // cute::copy(row_max, scores_max_prev);
  161. // flash::template reduce_max</*zero_init=*/false>(scores, row_max);
  162. // // if (cute::thread0()) { print_tensor(scores); printf("\n"); print_tensor(row_max); printf("\n"); }
  163. // #pragma unroll
  164. // for (int mi = 0; mi < size(row_max); ++mi) {
  165. // float scores_max_cur = !Check_inf
  166. // ? row_max(mi)
  167. // : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
  168. // scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
  169. // row_sum(mi) *= scores_scale(mi);
  170. // }
  171. flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf, Use_max_offset>(scores, row_max, softmax_scale_log2);
  172. // We don't do the reduce across threads here since we don't need to use the row_sum.
  173. // We do that reduce at the end when we need to normalize the softmax.
  174. flash::reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
  175. }
  176. return scores_scale;
  177. };
  178. template<bool Is_dropout=false, bool Split=false, typename Tensor0>
  179. __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float descale_v = 1.f, float rp_dropout=1.f) {
  180. constexpr static float max_offset_E = Use_max_offset ? 8.f * float(M_LN2) : 0.f;
  181. // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
  182. Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
  183. static_assert(decltype(size<0>(scores))::value == kNRows);
  184. SumOp<float> sum_op;
  185. quad_allreduce_(row_sum, row_sum, sum_op);
  186. TensorT scores_scale;
  187. #pragma unroll
  188. for (int mi = 0; mi < size(row_max); ++mi) {
  189. float sum = row_sum(mi);
  190. float inv_sum = (sum == 0.f || sum != sum) ? 0.f : descale_v / sum;
  191. row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : (row_max(mi) * softmax_scale_log2) * float(M_LN2) - max_offset_E + __logf(sum);
  192. scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
  193. }
  194. return scores_scale;
  195. };
  196. template<typename Tensor1>
  197. __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
  198. // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
  199. Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
  200. static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
  201. #pragma unroll
  202. for (int mi = 0; mi < size(row_max); ++mi) {
  203. #pragma unroll
  204. for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); }
  205. }
  206. };
  207. };
  208. } // namespace flash