softmax.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cmath>
  6. #include <cute/tensor.hpp>
  7. #include <cutlass/numeric_types.h>
  8. #include "utils.h"
  9. namespace flash {
  10. using namespace cute;
  11. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Per-thread row reduction of a rank-2 tensor into a rank-1 summary.
  // For every row r of `tensor`, all columns are folded with `op` into summary(r):
  //   zero_init == true  -> summary(r) starts from tensor(r, 0) (its previous value is ignored);
  //   zero_init == false -> summary(r) is combined with its previous value first.
  // No communication between threads happens here.
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
      static_assert(Layout0::rank == 2, "Only support 2D Tensor");
      static_assert(Layout1::rank == 1, "Only support 1D Tensor");
      CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
      #pragma unroll
      for (int row = 0; row < size<0>(tensor); ++row) {
          // Seed the accumulator with column 0.
          if constexpr (zero_init) {
              summary(row) = tensor(row, 0);
          } else {
              summary(row) = op(summary(row), tensor(row, 0));
          }
          // Fold the remaining columns.
          #pragma unroll
          for (int col = 1; col < size<1>(tensor); ++col) {
              summary(row) = op(summary(row), tensor(row, col));
          }
      }
  }
  // Element-wise all-reduce of `src` with `op`, writing each result into `dst`.
  // Uses Allreduce<4>::run, which is defined in utils.h and not visible here.
  // NOTE(review): presumably a shuffle-based reduction across a group of 4 lanes; confirm.
  // `dst` and `src` may alias; callers in this file pass the same tensor for both.
  template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  __device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
      CUTE_STATIC_ASSERT_V(size(dst) == size(src));
      #pragma unroll
      for (int idx = 0; idx < size(dst); ++idx) {
          auto const reduced = Allreduce<4>::run(src(idx), op);
          dst(idx) = reduced;
      }
  }
  // Full row reduction: a per-thread fold of each row (thread_reduce_), followed by
  // an in-place all-reduce of the per-row results with quad_allreduce_.
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  __device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
      // Step 1: local fold over the columns this thread holds.
      thread_reduce_<zero_init>(tensor, summary, op);
      // Step 2: combine the partial results across the thread group, in place.
      quad_allreduce_(summary, summary, op);
  }
  // Row-wise maximum of `tensor` (float MaxOp), reduced across the thread group via reduce_.
  // zero_init == false folds the previous contents of `max` into the result.
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
      MaxOp<float> op;
      reduce_<zero_init>(tensor, max, op);
  }
  // Row-wise sum of `tensor` (float SumOp).
  //   warp_reduce == true  -> per-thread fold followed by quad_allreduce_ (same as reduce_);
  //   warp_reduce == false -> per-thread partial sums only, with no cross-thread step.
  // zero_init == false adds onto the previous contents of `sum`.
  template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
      SumOp<float> op;
      if constexpr (warp_reduce) {
          reduce_<zero_init>(tensor, sum, op);
      } else {
          thread_reduce_<zero_init>(tensor, sum, op);
      }
  }
  // Computes 2^x (base-2 exponential, NOT e^x despite the name) for both fp16 lanes
  // of `x`, using the approximate PTX instruction ex2.approx.f16x2.
  // Per the PTX ISA, the f16x2 form of ex2 requires sm_75 or newer.
  // The __half2 value is reinterpreted as its 32-bit pattern so it can be passed in an
  // "r" register operand, and the result bits are reinterpreted back.
  __forceinline__ __device__ __half2 half_exp(__half2 x) {
      uint32_t bits_in = reinterpret_cast<uint32_t&>(x);
      uint32_t bits_out;
      asm ("ex2.approx.f16x2 %0, %1;\n" : "=r"(bits_out) : "r"(bits_in));
      return reinterpret_cast<__half2&>(bits_out);
  }
  // Fused row max + exponentiation + row sum, applied in place to a rank-2 tensor.
  // For each row r:
  //   1. max(r) <- maximum of the row. It is combined with the previous max(r) unless
  //      zero_init, then all-reduced across the thread group with Allreduce<4>.
  //   2. tensor(r, c) <- exp2(tensor(r, c) * scale - max(r) * scale).
  //      A max of -inf (every element -inf, e.g. fully masked) is treated as 0, so the
  //      row becomes exp2(-inf) = 0 instead of NaN from (-inf) - (-inf).
  //   3. sum(r) <- sum of this thread's exponentiated values for the row. It is reset
  //      to 0 first and is NOT reduced across threads.
  // NOTE(review): with zero_init == false the previous max is folded in but the previous
  // sum is discarded (not rescaled); callers must account for that.
  template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
      static_assert(Layout0::rank == 2, "Only support 2D Tensor");
      static_assert(Layout1::rank == 1, "Only support 1D Tensor");
      CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
      #pragma unroll
      for (int row = 0; row < size<0>(tensor); ++row) {
          MaxOp<float> max_op;
          // Step 1: thread-local row max ...
          if constexpr (zero_init) {
              max(row) = tensor(row, 0);
          } else {
              max(row) = max_op(max(row), tensor(row, 0));
          }
          #pragma unroll
          for (int col = 1; col < size<1>(tensor); ++col) {
              max(row) = max_op(max(row), tensor(row, col));
          }
          // ... then combined across the thread group.
          max(row) = Allreduce<4>::run(max(row), max_op);

          // Step 2: subtract the scaled max. Writing it as x * scale - max_scaled lets the
          // compiler emit a single FFMA instead of separate FADD and FMUL.
          float max_scaled;
          if (max(row) == -INFINITY) {
              max_scaled = 0.f;
          } else {
              max_scaled = max(row) * scale;
          }

          // Step 3: exponentiate and accumulate the thread-local row sum.
          sum(row) = 0;
          #pragma unroll
          for (int col = 0; col < size<1>(tensor); ++col) {
              tensor(row, col) = exp2f(tensor(row, col) * scale - max_scaled);
              sum(row) += tensor(row, col);
          }
      }
  }
  // Exponentiates a rank-2 tensor in place, row by row:
  //   tensor(r, c) <- exp2(tensor(r, c) * scale - max_scaled(r)),
  //   max_scaled(r) = max(r) * (Scale_max ? scale : log2(e)).
  // Computing exp2(x * scale - m) rather than exp(x - m) lets the compiler use one FFMA
  // per element instead of a separate FADD and FMUL.
  // Check_inf: a row max of -inf (every element -inf, e.g. fully masked) is replaced by 0,
  // so the row evaluates to exp2(-inf) = 0 instead of NaN from (-inf) - (-inf).
  template <bool Scale_max=true, bool Check_inf=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
      static_assert(Layout0::rank == 2, "Only support 2D Tensor");
      static_assert(Layout1::rank == 1, "Only support 1D Tensor");
      CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
      // Multiplier applied to the row max. M_LOG2E is cast to float so the product is
      // computed in fp32 rather than fp64.
      const float max_factor = Scale_max ? scale : float(M_LOG2E);
      #pragma unroll
      for (int row = 0; row < size<0>(tensor); ++row) {
          float max_scaled = max(row) * max_factor;
          if constexpr (Check_inf) {
              if (max(row) == -INFINITY) { max_scaled = 0.f; }
          }
          #pragma unroll
          for (int col = 0; col < size<1>(tensor); ++col) {
              tensor(row, col) = exp2f(tensor(row, col) * scale - max_scaled);
          }
      }
  }
  109. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Online-softmax state for kNRows rows whose scores each thread holds in registers.
  //
  // Members:
  //   row_max - running maximum of the raw (unscaled) scores per row. It is updated via
  //             reduce_max, which all-reduces it across the thread group (quad_allreduce_).
  //   row_sum - running sum of exp2(score * softmax_scale_log2 - max * softmax_scale_log2)
  //             per row. It stays a per-thread partial sum (no cross-thread reduce) until
  //             finalize(), which reduces it and then overwrites it with the row's LSE.
  //
  // Accumulator tensors are viewed as (rows, cols) through
  // flash::convert_layout_acc_rowcol (defined in utils.h, not visible here).
  template <int kNRows>
  struct Softmax {
      using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
      TensorT row_max, row_sum;

      CUTLASS_DEVICE Softmax() {}

      // Fold a new tile of scores into row_max and return the per-row factor by which
      // previously accumulated results must be rescaled:
      //   exp2((max_prev - max_new) * softmax_scale_log2).
      // row_sum is multiplied by that factor in place.
      // Is_first: row_max is initialized from this tile, the factor is 1, and row_sum is
      //           not touched.
      // Check_inf: a new max of -inf is replaced by 0 when computing the factor, so rows
      //            that are still fully -inf do not produce NaN.
      template<bool Is_first, bool Check_inf=false, typename Tensor0>
      __forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
          // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
          Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
          static_assert(decltype(size<0>(scores))::value == kNRows);
          TensorT scores_scale;
          if constexpr (Is_first) {
              flash::template reduce_max</*zero_init=*/true>(scores, row_max);
              cute::fill(scores_scale, 1.f);
          } else {
              Tensor scores_max_prev = make_fragment_like(row_max);
              cute::copy(row_max, scores_max_prev);
              flash::template reduce_max</*zero_init=*/false>(scores, row_max);
              #pragma unroll
              for (int mi = 0; mi < size(row_max); ++mi) {
                  float scores_max_cur = !Check_inf
                      ? row_max(mi)
                      : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                  scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                  row_sum(mi) *= scores_scale(mi);
              }
          }
          return scores_scale;
      }

      // Exponentiate the current tile of scores in place and add its per-thread row sums
      // into row_sum.
      //   Is_first:  row_max is initialized from this tile (group-reduced), the tile is
      //              exponentiated against it, and row_sum is reset to this tile's sum.
      //   otherwise: the tile is exponentiated against the current row_max, which is NOT
      //              updated here, and the tile's sum is added to row_sum.
      //              NOTE(review): this path appears to expect row_max (and the matching
      //              row_sum rescale) to already be updated for this tile, e.g. by
      //              max</*Is_first=*/false>() — confirm at call sites.
      // Returns the rescale factor this call imposes on previously accumulated outputs.
      // It is all ones on both paths, because neither path changes a previous row_max.
      template<bool Is_first, bool Check_inf=false, typename Tensor0>
      __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
          // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
          Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
          static_assert(decltype(size<0>(scores))::value == kNRows);
          TensorT scores_scale;
          if constexpr (Is_first) {
              flash::template reduce_max</*zero_init=*/true>(scores, row_max);
              flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2);
              flash::reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
          } else {
              flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf>(scores, row_max, softmax_scale_log2);
              // The cross-thread reduce of row_sum is deferred to finalize(), the only
              // place the full row sum is needed.
              flash::reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
          }
          // Previously only the Is_first path initialized scores_scale, so the other path
          // returned uninitialized registers. A factor of 1 is correct for both paths.
          cute::fill(scores_scale, 1.f);
          return scores_scale;
      }

      // Finish the softmax for every row:
      //   * all-reduces row_sum across the thread group;
      //   * returns the per-row normalization factor 1 / sum (times rp_dropout when
      //     Is_dropout), or 0 for rows whose sum is 0 or NaN;
      //   * overwrites row_sum with row_max * softmax_scale_log2 * ln(2) + log(sum).
      //     This is the natural-log log-sum-exp if softmax_scale_log2 == softmax_scale * log2(e)
      //     (assumed; confirm at call sites). Rows whose sum is 0 or NaN get +inf instead
      //     (-inf when Split).
      // acc_s is only used to check the row count; it is not modified.
      template<bool Is_dropout=false, bool Split=false, typename Tensor0>
      __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0f) {
          // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
          Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
          static_assert(decltype(size<0>(scores))::value == kNRows);
          SumOp<float> sum_op;
          quad_allreduce_(row_sum, row_sum, sum_op);
          TensorT scores_scale;
          #pragma unroll
          for (int mi = 0; mi < size(row_max); ++mi) {
              float const sum = row_sum(mi);
              // `sum != sum` is true only for NaN.
              bool const degenerate = (sum == 0.f) || (sum != sum);
              float const inv_sum = degenerate ? 0.f : 1.f / sum;
              row_sum(mi) = degenerate
                  ? (Split ? -INFINITY : INFINITY)
                  : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
              scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
          }
          return scores_scale;
      }

      // Multiply every element of row mi of acc_o by scores_scale(mi), in place.
      template<typename Tensor1>
      __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
          // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
          Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
          static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
          #pragma unroll
          for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
              #pragma unroll
              for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); }
          }
      }
  };
  200. } // namespace flash