#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include "cutlass/layout/layout.h"
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "kernel_traits.h"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <class Element, class SmemShape, class SmemShapeMaxSplits>
struct SharedStorageLSE {
    cute::array_aligned<Element, cute::size_v<SmemShape>> smem_lse;
    cute::array_aligned<bool, cute::size_v<SmemShapeMaxSplits>> smem_valid_splits;
};
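
// combine_attn_seqk_parallel is the "combine" step of split-KV attention: each of the num_splits partial
// passes over the key/value sequence has produced a partial output accumulator (params.oaccum_ptr) and a
// per-row log-sum-exp (params.softmax_lseaccum_ptr). Each thread block takes kBlockM consecutive rows,
// computes the logsumexp over the split dimension, writes the final LSE, and accumulates the per-split
// outputs weighted by exp(lse_split - lse_final) into the final output O.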
// Don't use Kernel_traits here, to avoid redundant compilation.
// template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
template<typename Element, typename ElementAccum, int kHeadDim, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
__global__ void combine_attn_seqk_parallel(Params const params) {
    // using Element = typename Kernel_traits::OutputType;
    // using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = int64_t;  // Kernel_traits::index_t
    constexpr int kMaxSplits = 1 << Log_max_splits;
    // constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNThreads = 128;  // Kernel_traits::kNThreads;
    static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
    static_assert(kNThreads == 128, "We assume that each block has 128 threads");

    // Shared memory.
    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
    // __shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM + 1];
    extern __shared__ char smem_[];
    using SharedStorage = SharedStorageLSE<ElementAccum, Shape<Int<kMaxSplits>, Int<kBlockM + 1>>, Shape<Int<kMaxSplits>>>;
    SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(smem_);
    Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kMaxSplits>, Int<kBlockM + 1>>{});
    Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape<Int<kMaxSplits>>{});
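    // sLSE caches this block's kMaxSplits x kBlockM tile of per-split LSE values; sValidSplits records, per
    // split, whether any of those rows has a finite LSE, so splits that contributed nothing can be skipped
    // entirely when the partial outputs are accumulated below.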

    // The thread and block index.
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    const index_t lse_size = params.b * params.h * params.seqlen_q;
    // if (cute::thread0()) print("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
                                   make_stride(lse_size, _1{}));
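    // In gLSEaccum the first mode indexes the split (stride lse_size: consecutive splits of the accumulated
    // LSE are lse_size rows apart in memory) and the second mode indexes the kBlockM rows owned by this block.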

    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
    Layout flat_layout = make_layout(lse_size);
    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
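    // final_layout composes the default (seqlen_q, h, b) coordinate mapping with the transposed strides, so
    // that indexing gLSE_unpadded with a flat row index lands on the matching element of the unpadded (and
    // possibly ngroups-swapped) LSE tensor.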

    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

    // Read the LSE values from gmem and store them in shared memory, then transpose them.
    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
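    // Each iteration of this loop loads kRowsPerLoadLSE split-rows with kBlockM consecutive threads per row,
    // so kNLsePerThread iterations let the 128 threads cover the whole kMaxSplits x kBlockM tile. Entries
    // beyond params.num_splits (or past the end of the LSE) are filled with -INFINITY.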
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE(row, col) = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
    }
    __syncthreads();

    // Reduce along the kBlockM dimension to determine the valid splits (stored in smem).
    // One thread per split. We know kNThreads = 128 >= kMaxSplits.
    if (tidx < kMaxSplits) {
        bool is_valid_split = false;
        #pragma unroll
        for (int col = 0; col < kBlockM; ++col) {
            if (sLSE(tidx, col) != -INFINITY) {
                is_valid_split = true;
            }
        }
        sValidSplits(tidx) = is_valid_split;
    }
    __syncthreads();

    // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
    // To keep the reduction over kMaxSplits within one warp, we decide how many elements of the split
    // dimension each thread should hold. E.g. if kMaxSplits = 16 and kBlockM = 16, each thread holds
    // 2 elements (128 threads, kBlockM rows, so each load covers 128 / kBlockM rows).
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
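    // In the transposed read below, each thread gathers kNLsePerThread entries of the split dimension for a
    // single output row (col = tidx / kRowsPerLoadTranspose); the kRowsPerLoadTranspose adjacent lanes that
    // share a row together cover all kMaxSplits entries, so the max/sum reductions over splits can be done
    // with a shuffle-based Allreduce over just those lanes.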
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        // if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -INFINITY;
    }
    // return;

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp<float> max_op;
    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp<float> sum_op;
    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
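    // The combined LSE for each of the kBlockM rows is written out by the first lane of its
    // kRowsPerLoadTranspose-lane group, either to the unpadded LSE tensor or to the padded one.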
    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
        if (params.unpadded_lse) {
            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
            if (lse_offset < lse_size) {
                gLSE_unpadded(lse_offset) = lse_logsum;
            }
        } else {
            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
        }
    }
    // if (cute::thread0()) printf("lse_logsum = %f\n", lse_logsum);

    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE(row, col) = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();
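    // sLSE now holds, for each (split, row) pair, the weight exp(lse_split - lse_logsum) by which that split's
    // partial output is scaled before being accumulated into the final output.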

    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    constexpr int kBlockN = kNThreads / kBlockM;
    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
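    // The tiled copy arranges the 128 threads as a kBlockM x kBlockN grid (kBlockN = 128 / kBlockM threads per
    // row), each thread moving 4 contiguous ElementAccum values per copy, i.e. a kBlockM x (4 * kBlockN) tile
    // of gOaccum per copy iteration.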
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    // if (cute::thread0()) print_tensor(cOaccum);
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
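    // When the head dimension is padded (Is_even_K == false), tOpOaccum(k) marks whether column k of this
    // thread's slice falls inside the true head dimension params.d, so out-of-range columns are neither read
    // from Oaccum nor written to O.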

    // Load Oaccum, then scale and accumulate into O.
    for (int split = 0; split < params.num_splits; ++split) {
        // Don't copy Oaccum in if lse(split) == -inf for all kBlockM rows.
        if (sValidSplits(split)) {
            flash::copy</*Is_even_MN=*/false, Is_even_K>(
                gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
            );
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOaccum); ++m) {
                int row = get<0>(tOcOaccum(0, m, 0));
                ElementAccum lse_scale = sLSE(split, row);
                if (lse_scale != 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                        #pragma unroll
                        for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                            tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                            // tOrO(i, m, k) += tOrOaccum(i, m, k);
                        }
                    }
                }
                // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); }
            }
        }
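        // Advance to the next split's Oaccum buffer: splits are the outermost dimension of oaccum, each split
        // occupying b * h * seqlen_q * d_rounded elements.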
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
    // if (cute::thread0()) { print_tensor(tOrO); }

    Tensor rO = flash::convert_type<Element>(tOrO);
    // Write to gO
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        // if (cute::thread0()) print("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
        if (idx < params.b * params.h * params.seqlen_q) {
            // print("final2\n");
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index into the rows of Q.
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast.
                    copy(rO(_, m, k), gO);
                    // if (cute::thread0()) { print("final\n"); print_tensor(gO); }
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
                }
            }
        }
    }
}
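
// A minimal launch sketch (illustrative only, not the library's actual dispatch code). The tile size
// kBlockM, Log_max_splits, the element types, and the "params"/"stream" variables below are assumptions:
// in practice they are chosen at runtime from params.num_splits and the problem sizes, and params must be
// a fully populated forward-params struct.
//
//   using Element = cutlass::half_t;
//   using ElementAccum = float;
//   constexpr int kHeadDim = 128, kBlockM = 16, Log_max_splits = 4;  // 2^4 = 16 >= params.num_splits
//   const int total_rows = params.b * params.h * params.seqlen_q;
//   dim3 grid((total_rows + kBlockM - 1) / kBlockM), block(128);
//   using Smem = SharedStorageLSE<ElementAccum,
//                                 Shape<Int<(1 << Log_max_splits)>, Int<kBlockM + 1>>,
//                                 Shape<Int<(1 << Log_max_splits)>>>;
//   combine_attn_seqk_parallel<Element, ElementAccum, kHeadDim, kBlockM, Log_max_splits,
//                              /*Is_even_K=*/true>
//       <<<grid, block, sizeof(Smem), stream>>>(params);
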
}  // namespace flash