// softmax.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cmath>

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
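
// Per-thread reduction: collapse a 2D (nrow, ncol) register fragment along the column
// dimension into the 1D per-row `summary`, combining values with `op` (e.g. max or sum).
// With zero_init, column 0 initializes the summary instead of combining with a stale value.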
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int ni = 0; ni < size<1>(tensor); ni++) {
        #pragma unroll
        for (int mi = 0; mi < size<0>(tensor); mi++) {
            summary(mi) = zero_init && ni == 0 ? tensor(mi, ni) : op(summary(mi), tensor(mi, ni));
        }
    }
}
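
// Cross-thread reduction: in the MMA accumulator layout used here, each row is spread across
// a group of 4 threads, so Allreduce<4>::run (from utils.h, a warp-shuffle reduction over
// those 4 lanes) is enough to give every thread the full row-wise result.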
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++) {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max) {
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum) {
    SumOp<float> sum_op;
    thread_reduce_<zero_init>(tensor, sum, sum_op);
    if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}
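
// reduce_max always does both the per-thread and the quad reduction; reduce_sum can skip the
// quad step (warp_reduce=false) so a caller may accumulate partial per-thread sums across
// iterations and defer the cross-thread reduction until the end (see Softmax::finalize below).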

// Apply the exp to all the elements.
template <bool Scale_max=true, bool Check_inf=true, int Max_offset=0,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    // For FP8, we can subtract 8.0 from the max so that the value after exp2 is in the range of [0, 256].
    // This lets us use more of the FP8 range (instead of just [0, 1]) to reduce underflow.
    static constexpr float max_offset = float(Max_offset);  // We can only template on int, not float
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = Check_inf
            ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset)
            : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            // This allows the compiler to use the ffma instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        }
    }
}
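
// The identity being used: exp((x - m) * softmax_scale) == exp2(x * scale - m * scale) with
// scale = softmax_scale * log2(e), up to floating-point rounding. This is the value the
// Softmax struct below stores as softmax_scale_log2 (finalize() multiplies it by M_LN2 to
// recover softmax_scale), so a single FFMA per element feeds exp2f directly.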

////////////////////////////////////////////////////////////////////////////////////////////////////
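
// Online (streaming) softmax state for kNRows accumulator rows per thread: row_max holds the
// running maximum and row_sum the running (rescaled) sum of exponentials for each row.
// Max_offset != 0 enables the FP8 range trick described in scale_apply_exp2 and is
// compensated for in finalize().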
template <int kNRows, int Max_offset=0>
struct Softmax {

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;
    float const softmax_scale_log2;

    CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {};

    template<bool Is_first, bool Check_inf=false, typename Tensor0>
    __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows);
        TensorT scores_scale;
        if constexpr (Is_first) {
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            cute::fill(scores_scale, 1.f);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf
                    ? row_max(mi)
                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                row_sum(mi) *= scores_scale(mi);
            }
        }
        return scores_scale;
    };
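
    // The scores_scale returned by max_get_scale above is exp2((m_prev - m_new) * softmax_scale_log2):
    // the factor the caller multiplies into the output accumulator (via rescale_o) when the running
    // max increases, so previously accumulated exponentials stay consistent with the new max.
    // row_sum is rescaled by the same factor internally.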

    template<bool Is_first, bool Check_inf=false, typename Tensor0>
    __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows);
        flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf, Max_offset>(scores, row_max, softmax_scale_log2);
        // We don't do the reduce across threads here since we don't need to use the row_sum.
        // We do that reduce at the end when we need to normalize the softmax.
        flash::reduce_sum</*zero_init=*/Is_first, /*warp_reduce=*/false>(scores, row_sum);
    };

    __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) {
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT scores_scale;
        #pragma unroll
        for (int mi = 0; mi < size(row_sum); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
            scores_scale(mi) = inv_sum * final_scale;
            // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount.
            if constexpr (Max_offset != 0) {
                static constexpr float sum_scale = 1.f / float(1 << Max_offset);
                sum *= sum_scale;
            }
            row_sum(mi) = (sum == 0.f || sum != sum) ? -INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
        }
        return scores_scale;
    };
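
    // After finalize, row_sum no longer holds the sum of exponentials: it holds the per-row
    // log-sum-exp in natural-log units, max * softmax_scale + log(sum), with -INFINITY for
    // empty (all-masked) rows. The returned scores_scale = final_scale / sum (or 0 for empty
    // rows) is the normalization factor to apply to the output via rescale_o.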

    template<typename Tensor1>
    __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); }
        }
    };

};
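
// Illustrative sketch of how these pieces are typically combined in an attention mainloop.
// The tensor names, loop structure, and template arguments here are assumptions for
// exposition, not an API contract defined by this header:
//
//   flash::Softmax<kNRows> softmax(softmax_scale * float(M_LOG2E));
//   for (each K/V tile) {
//       // acc_s = Q @ K^T for this tile (MMA accumulator fragment)
//       auto scores_scale = softmax.max_get_scale</*Is_first=*/is_first_tile>(acc_s);
//       softmax.rescale_o(acc_o, scores_scale);                      // fix up output accumulated so far
//       softmax.online_softmax</*Is_first=*/is_first_tile>(acc_s);   // exponentiate + accumulate row_sum
//       // acc_o += convert(acc_s) @ V for this tile
//   }
//   auto final_scale = softmax.finalize();    // cross-thread sum, 1/sum; row_sum becomes the LSE
//   softmax.rescale_o(acc_o, final_scale);    // normalize the output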

} // namespace flash