softmax.h 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cmath>
  6. #include <cute/tensor.hpp>
  7. #include <cutlass/numeric_types.h>
  8. #include "philox.cuh"
  9. #include "utils.h"
  10. namespace flash {
  11. using namespace cute;
  12. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Per-thread row reduction of a 2-D tensor into a 1-D summary (one entry per row),
  // using the binary functor `op`. Only the elements this thread holds take part;
  // there is no communication with other threads here.
  //   zero_init == true  : summary(r) is overwritten, seeded from tensor(r, 0).
  //   zero_init == false : the incoming summary(r) is folded with the row, so the
  //                        previous contents act as a running value across calls.
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
      static_assert(Layout0::rank == 2, "Only support 2D Tensor");
      static_assert(Layout1::rank == 1, "Only support 1D Tensor");
      CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
      #pragma unroll
      for (int r = 0; r < size<0>(tensor); ++r) {
          // Seed the row from column 0 (optionally combined with the prior summary).
          if (zero_init) {
              summary(r) = tensor(r, 0);
          } else {
              summary(r) = op(summary(r), tensor(r, 0));
          }
          // Fold the remaining columns left to right.
          #pragma unroll
          for (int c = 1; c < size<1>(tensor); ++c) {
              summary(r) = op(summary(r), tensor(r, c));
          }
      }
  }
  // Element-wise cross-thread reduction: dst(k) <- Allreduce<4>::run(src(k), op) for
  // every k. Allreduce<4> is defined in utils.h (not visible here); presumably it
  // combines values across groups of 4 lanes — confirm there.
  // dst and src may be the same tensor (callers in this file pass it twice), since
  // each element is read before it is written.
  template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  __device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
      CUTE_STATIC_ASSERT_V(size(dst) == size(src));
      #pragma unroll
      for (int k = 0; k < size(dst); ++k) {
          dst(k) = Allreduce<4>::run(src(k), op);
      }
  }
  // Row-wise reduction of `tensor` into `summary` in two stages: a per-thread fold
  // over this thread's columns (thread_reduce_, honoring zero_init), followed by an
  // in-place cross-thread combine of the per-thread results (quad_allreduce_).
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
  __device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
      // Stage 1: local fold.
      thread_reduce_<zero_init>(tensor, summary, op);
      // Stage 2: cross-thread combine, writing back into the same tensor.
      quad_allreduce_(summary, summary, op);
  }
  // Row-wise maximum of `tensor` into `max`, including the cross-thread combine
  // performed by reduce_. With zero_init == false the incoming max(r) also takes
  // part, which produces a running maximum across successive calls.
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
      MaxOp<float> op_max;
      reduce_<zero_init>(tensor, max, op_max);
  }
  // Row-wise sum of `tensor` into `sum` over THIS THREAD'S elements only.
  // NOTE: unlike reduce_max, no cross-thread (quad_allreduce_) step is performed, so
  // `sum` is a per-thread partial on return. Callers that need the full row sum must
  // combine it themselves (Softmax::normalize_softmax_lse does this with
  // quad_allreduce_ and SumOp<float>).
  // With zero_init == false the incoming sum(r) is accumulated into.
  template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
      SumOp<float> op_sum;
      // Deliberately thread-local; the cross-thread stage is deferred to the caller.
      thread_reduce_<zero_init>(tensor, sum, op_sum);
  }
  // Exponentiate a 2-D score tensor in place, row by row:
  //   tensor(r, c) <- exp2(tensor(r, c) * scale - max(r) * f)
  // where f = scale when Scale_max is true, and f = log2(e) otherwise.
  // A row whose max is -inf (every element -inf, e.g. after masking) subtracts 0
  // instead, so the result is exp2(-inf) rather than NaN from (-inf) - (-inf).
  // Computing exp2(x * scale - m) instead of exp(x - max) lets the compiler emit
  // a single FFMA rather than a separate FMUL and FADD.
  // When UNFUSE_FMA is defined (PyTorch sets it; FlashAttention does not), the
  // multiply goes through __fmul_rn so it is not fused with the subtraction; see
  // https://github.com/pytorch/pytorch/issues/121558.
  template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
      static_assert(Layout0::rank == 2, "Only support 2D Tensor");
      static_assert(Layout1::rank == 1, "Only support 1D Tensor");
      CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
      // float() keeps the product in fp32; a bare M_LOG2E is a double and would
      // push the multiplication to fp64.
      const float max_factor = Scale_max ? scale : float(M_LOG2E);
      #pragma unroll
      for (int r = 0; r < size<0>(tensor); ++r) {
          float row_offset = 0.f;
          if (max(r) != -INFINITY) {
              row_offset = max(r) * max_factor;
          }
          #pragma unroll
          for (int c = 0; c < size<1>(tensor); ++c) {
  #ifdef UNFUSE_FMA
              tensor(r, c) = exp2f(__fmul_rn(tensor(r, c), scale) - row_offset);
  #else
              tensor(r, c) = exp2f(tensor(r, c) * scale - row_offset);
  #endif
          }
      }
  }
  // Fused per-row max / exponentiate / sum over a 2-D score tensor, in place:
  //   1. max(r) <- maximum of row r (folded with the incoming max(r) when zero_init
  //      is false), then combined across threads with Allreduce<4>.
  //   2. tensor(r, c) <- exp2(tensor(r, c) * scale - max(r) * scale). A row whose
  //      max is -inf subtracts 0 instead, avoiding NaN from (-inf) - (-inf).
  //   3. sum(r) <- sum of the new row values, always restarted from 0 (zero_init
  //      affects only the max), then combined across threads with Allreduce<4>.
  // Unlike scale_apply_exp2 there is no Scale_max option: `scale` multiplies both
  // the scores and the max (presumably softmax_scale * log2(e) — confirm at call
  // sites). Unlike reduce_sum, the returned `sum` is already cross-thread reduced.
  // Honors UNFUSE_FMA exactly like scale_apply_exp2 (the multiply is done with
  // __fmul_rn so it is not fused with the subtraction); see
  // https://github.com/pytorch/pytorch/issues/121558.
  template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
  __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
      static_assert(Layout0::rank == 2, "Only support 2D Tensor");
      static_assert(Layout1::rank == 1, "Only support 1D Tensor");
      CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
      CUTE_STATIC_ASSERT_V(size<0>(sum) == size<0>(tensor));
      // Stateless functors: build once instead of once per row.
      MaxOp<float> max_op;
      SumOp<float> sum_op;
      #pragma unroll
      for (int mi = 0; mi < size<0>(tensor); ++mi) {
          // Step 1: row max, local fold then cross-thread combine.
          max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
          #pragma unroll
          for (int ni = 1; ni < size<1>(tensor); ni++) {
              max(mi) = max_op(max(mi), tensor(mi, ni));
          }
          max(mi) = Allreduce<4>::run(max(mi), max_op);
          // If max is -inf, every element was -inf (possibly due to masking); use 0
          // so we never compute (-inf) - (-inf) = NaN.
          const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
          // Steps 2 and 3: exponentiate in base 2 and accumulate the row sum.
          sum(mi) = 0.f;
          #pragma unroll
          for (int ni = 0; ni < size<1>(tensor); ++ni) {
  #ifdef UNFUSE_FMA
              tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
  #else
              tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
  #endif
              sum(mi) += tensor(mi, ni);
          }
          sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
      }
  }
  109. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Running per-row statistics for an online (tile-by-tile) softmax.
  //
  //   row_max: running maximum of the *unscaled* scores. The scale is applied only
  //            when exponentiating and when forming the log-sum-exp.
  //   row_sum: running sum of the exp2 values. Between softmax_rescale_o calls it
  //            holds only this thread's partial sum; the cross-thread reduction is
  //            deferred to normalize_softmax_lse.
  //
  // The constructor does not initialize either member, so the first tile must be
  // processed with softmax_rescale_o<Is_first=true>.
  template <int kNRows>
  struct Softmax {
      using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
      TensorT row_max, row_sum;

      __forceinline__ __device__ Softmax() {}

      // Fold one tile of scores into the running statistics.
      //   acc_s: score accumulator, viewed via convert_layout_acc_rowcol as
      //          (nrow=(2, MMA_M), ncol=(2, MMA_N)) from (MMA=4, MMA_M, MMA_N). It is
      //          overwritten in place with
      //          exp2(score * softmax_scale_log2 - row_max * softmax_scale_log2).
      //   acc_o: output accumulator, viewed the same way. For non-first tiles each row
      //          is multiplied by exp2((old_max - new_max) * softmax_scale_log2), so it
      //          stays relative to the updated row max. row_sum gets the same factor.
      //   softmax_scale_log2: multiplier applied inside the base-2 exponent
      //          (presumably softmax_scale * log2(e) — confirm at call sites).
      //   Is_first: initialize row_max/row_sum from this tile instead of folding.
      //   Check_inf: treat a new row max of -inf as 0 in the rescale factor. Since
      //          the new max already includes the old one, a -inf new max means the
      //          old max was -inf too. Without this, that case computes
      //          (-inf) - (-inf) = NaN; with it, the factor is exp2(-inf) = 0.
      template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
      __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
          // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
          Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
          static_assert(decltype(size<0>(scores))::value == kNRows);
          if (Is_first) {
              flash::reduce_max</*zero_init=*/true>(scores, row_max);
              flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
              flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
          } else {
              Tensor scores_max_prev = make_fragment_like(row_max);
              cute::copy(row_max, scores_max_prev);
              flash::reduce_max</*zero_init=*/false>(scores, row_max);
              // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
              Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
              static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
              #pragma unroll
              for (int mi = 0; mi < size(row_max); ++mi) {
                  float scores_max_cur = !Check_inf
                      ? row_max(mi)
                      : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                  float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                  row_sum(mi) *= scores_scale;
                  #pragma unroll
                  for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
              }
              flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
              // We don't do the reduce across threads here since we don't need to use the row_sum.
              // We do that reduce at the end when we need to normalize the softmax.
              flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
          }
      }

      // Finish the softmax and return the per-row log-sum-exp.
      //   1. Combine the per-thread partial row_sum across threads (quad_allreduce_).
      //   2. Multiply each row of acc_o by 1/row_sum, times rp_dropout when Is_dropout.
      //   3. lse(r) = row_max(r) * softmax_scale + log(row_sum(r)).
      // `softmax_scale` multiplies row_max directly and is paired with a natural log
      // (__logf), so it is presumably the natural-base scale, not the log2 one used
      // in softmax_rescale_o — confirm at call sites.
      // Degenerate rows (row_sum == 0 or NaN) are not divided by the sum (only
      // rp_dropout is applied). They get lse = -inf when Split and +inf otherwise
      // (presumably -inf so split partials contribute nothing when combined — confirm).
      template<bool Is_dropout=false, bool Split=false, typename Tensor0>
      __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0f) {
          SumOp<float> sum_op;
          quad_allreduce_(row_sum, row_sum, sum_op);
          TensorT lse = make_fragment_like(row_sum);
          Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
          static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
          #pragma unroll
          for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
              const float sum = row_sum(mi);
              // `sum != sum` is true only for NaN.
              const bool degenerate = (sum == 0.f || sum != sum);
              const float inv_sum = degenerate ? 1.f : 1.f / sum;
              lse(mi) = degenerate ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
              const float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
              #pragma unroll
              for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
          }
          return lse;
      }
  };
  166. } // namespace flash