// flash_bwd_preprocess_kernel.h

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;
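
// Preprocessing step for the FlashAttention backward pass. For each row of the output,
// this kernel computes dP_sum = rowsum(dO * O), rewrites LSE in base 2 (lse * log2(e))
// for the main backward kernel, and optionally zero-initializes the dQ accumulator and
// the dq_semaphore used when accumulating dQ.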
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
class FlashAttnBwdPreprocess {

public:

    // Type Aliases
    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;

    static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
                  std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
                  std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);

    static constexpr uint32_t MaxThreadsPerBlock = 256;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
    static constexpr int SharedStorageSize = 0;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");

    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kHeadDim = get<1>(TileShape_MK{});
    // We want kBlockKGmem to be a power of 2 so that when we do the summing,
    // it's just between threads in the same warp
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
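    // Worked example: for fp16/bf16, kGmemElemsPerLoad = 16 / 2 = 8; with kHeadDim = 128,
    // kBlockKGmem = 128 and kGmemThreadsPerRow = 128 / 8 = 16, so each tile row is read by
    // 16 threads of the same warp and the row sum below needs only an intra-warp reduction.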
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopy = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load

    static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
    static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
    using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
    using GmemTiledCopyAccum = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                        GmemLayoutAtomAccum{},
                        Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store

    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q, head, batch)
    using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
    // Device side arguments
    struct Arguments {
        Element const* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        Element const* ptr_dO;
        StrideO const stride_dO;
        float* ptr_dPsum;
        ShapedPsum const shape_dPsum;
        StridedPsum const stride_dPsum;
        float const* ptr_LSE;
        StridedPsum const stride_LSE;
        float* ptr_LSE_log2;
        StridedPsum const stride_LSE_log2;
        ElementAccum* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        int num_batch;  // We need this to know the size of dq_semaphore in case of varlen
        int* dq_semaphore;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Kernel entry point API
    struct Params {
        Element const* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        Element const* ptr_dO;
        StrideO const stride_dO;
        float* ptr_dPsum;
        ShapedPsum const shape_dPsum;
        StridedPsum const stride_dPsum;
        float const* ptr_LSE;
        StridedPsum const stride_LSE;
        float* ptr_LSE_log2;
        StridedPsum const stride_LSE_log2;
        ElementAccum* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        int num_batch;
        int* dq_semaphore;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };
    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        return {
            args.ptr_O,
            args.shape_O,
            args.stride_O,
            args.ptr_dO,
            args.stride_dO,
            args.ptr_dPsum,
            args.shape_dPsum,
            args.stride_dPsum,
            args.ptr_LSE,
            args.stride_LSE,
            args.ptr_LSE_log2,
            args.stride_LSE_log2,
            args.ptr_dQaccum,
            args.shape_dQaccum,
            args.stride_dQaccum,
            args.num_batch,
            args.dq_semaphore,
            args.cu_seqlens,
            args.seqused
        };
    }
    CUTLASS_DEVICE
    void
    operator()(Params const& params, [[maybe_unused]] char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MK{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const bidh = blockIdx.y;
        int const bidb = blockIdx.z;

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
        bool const is_varlen = Varlen && params.cu_seqlens;
        int const seqlen_o = seqlen_info.seqlen;
        if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
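
        // Global-memory views of O, dO, and LSE for this (head, batch). For varlen inputs the
        // batch dimension is collapsed and the per-sequence start is applied via seqlen_info.offset.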
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)
        Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)

        auto shape_LSE = select<0, 2, 3>(params.shape_O);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
        static_assert(kBlockM <= MaxThreadsPerBlock);
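        // One thread per row reads LSE; out-of-range rows are given +INF, and rows whose LSE is
        // -INF (e.g. fully masked rows) are stored as 0 in the LSE_log2 write below.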
        float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;

        GmemTiledCopy gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
        Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
        // Construct identity layout for gO
        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
        // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
        Tensor tOrO = make_fragment_like(tOgO);
        Tensor tOrdO = make_fragment_like(tOgdO);
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
            gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
        );
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
            gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
        );
        // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}

        // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
        Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
        Tensor tOrO_l = make_tensor(tOrO.data(), l);
        Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
        flash::convert_type_out(tOrO_l, o_fp32);
        Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
        Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
        flash::convert_type_out(tOrdO_l, do_fp32);
        // Sum across the last dimension
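        // Each thread accumulates the partial dot product of dO and O over the elements it loaded,
        // then Allreduce across the kGmemThreadsPerRow threads sharing a row gives the full row sum.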
        Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
        #pragma unroll
        for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
            float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
            #pragma unroll
            for (int ni = 1; ni < size<1>(o_fp32); ni++) {
                dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
            }
            flash::SumOp<float> sum_op;
            dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
        }
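
        // Only the threads holding column 0 of the tile write dP_sum; rows past the end of the
        // sequence within this block are zero-filled (the buffer is padded to a multiple of kBlockM).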
        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
        if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(dP_sum); ++mi) {
                int const row = get<0>(tOcO(_0{}, mi, _0{}));
                gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
            }
        }

        int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
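        // Store LSE rescaled to base 2 (lse * log2(e)) for rows up to seqlen_rounded, the
        // kBlockM-padded length that the main backward kernel indexes.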
        Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
        if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
            gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
        }
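
        // Optionally zero-initialize this block's slice of the dQ accumulator (stored as a flat
        // (seqlen_q * d) buffer per head) so the main backward kernel can accumulate into it.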
        if constexpr (Clear_dQaccum) {
            Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
            Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
            GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
            auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
            Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
            Tensor zero = make_fragment_like(tdQgdQaccum);
            clear(zero);
            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
        }
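
        // Reset this block's dq_semaphore slot (one counter per (m_block, batch, head)), used by
        // the main backward kernel when accumulating dQ.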
        if (params.dq_semaphore != nullptr && thread_idx == 0) {
            int const num_batch = params.num_batch;
            int const num_head = get<2>(params.shape_O);
            params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
        }

    }

};
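
// Usage sketch (illustrative, not part of this header): a launcher would typically instantiate
// the functor with a (kBlockM, kHeadDim) tile, e.g.
//   using Preprocess = flash::FlashAttnBwdPreprocess<
//       cute::Shape<cute::_128, cute::_64>, cutlass::half_t, float, cutlass::arch::Sm90,
//       /*Clear_dQaccum=*/true, /*Varlen=*/false>;
// fill Preprocess::Arguments, convert them with Preprocess::to_underlying_arguments(), and launch
// a kernel over a grid of (ceil_div(seqlen_q, kBlockM), num_heads, batch) blocks with
// Preprocess::MaxThreadsPerBlock threads, calling the functor's operator() from each block.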

} // namespace flash