kernel_traits.h
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
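
// Illustrative note (not part of the original header): a hypothetical instantiation such as
//     using Base_traits = Flash_kernel_traits</*kHeadDim=*/128, /*kBlockM=*/128, /*kBlockN=*/64, /*kNWarps=*/4, cutlass::bfloat16_t>;
// selects, when compiled for SM80 or newer, the 16x8x16 BF16 tensor-core MMA atom, LDSM-based
// shared-memory loads, and cp.async global loads. On older architectures the element type falls
// back to half_t, the SM75 16x8x8 MMA atom is used, and cp.async is disabled (LDSM smem loads
// still apply down to SM75).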

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        Tile<Int<16 * kNWarps>, _16, _16>>;
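    // Illustrative note (not in the original source): with kNWarps = 4, the tiled MMA above covers
    // a 64x16x16 (M x N x K) tile per instruction group: the four warps are stacked along M
    // (16 * kNWarps = 64), and the atom's N=8 is repeated twice to reach the _16 in the Tile<>
    // permutation.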
    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
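    // Worked example (illustrative values only, not from the original source): for kBlockM = 128,
    // kBlockN = 64, kHeadDim = 128 and a 2-byte Element, kSmemQSize = 128 * 128 * 2 = 32 KB and
    // kSmemKVSize = 2 * 64 * 128 * 2 = 32 KB, so kSmemSize is 64 KB without Q/K sharing and
    // 32 KB when Share_Q_K_smem is true.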

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page taking care of columns
    // 0-63 and 64-127. If we had 16 threads per row for the gmem read, then when writing to smem,
    // threads 0-7 would write to the first page and threads 8-15 to the second page, hitting
    // the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
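    // Worked example (assumed values, not from the original source): for a 2-byte Element,
    // kGmemElemsPerLoad = 16 / 2 = 8. With kBlockKSmem = 64 this gives kGmemThreadsPerRow = 8,
    // so for kNThreads = 128 the GmemLayoutAtom is a 16 x 8 thread grid in which each thread
    // covers one 8-element (128-bit) chunk per row.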
    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    // From how many rows each thread has to fetch
    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
    // mainloop iteration. Rudimentary testing shows no slowdown.
    using GmemTiledCopyQKVPaged = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
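    // Worked example (assumed values, not from the original source): with kBlockN = 64,
    // kNThreads = 128 and kGmemThreadsPerRow = 8, kGmemRowsPerThread = 64 / (128 / 8) = 4, so per
    // the comment above each thread handles a contiguous 4-row x 8-element tile of K/V and only
    // needs to look up a single KV-cache page index per mainloop iteration.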
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
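    // Illustrative note (not in the original source): the accumulator is float, so the 4-value
    // layout above assigns 4 contiguous floats (16 bytes) per thread per row; with 8 threads per
    // row (kBlockKSmem == 32) or 16 threads per row (kBlockKSmem == 64), one row of threads covers
    // 8 * 4 = 32 or 16 * 4 = 64 accumulator elements, matching kBlockKSmem.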
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
    using GmemTiledCopyRotcossinPaged = decltype(
        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinContPaged = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
};
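
// Illustrative usage sketch (hypothetical, not part of the original header): the forward kernel
// would typically be launched with traits such as
//     using Kernel_traits = Flash_fwd_kernel_traits</*kHeadDim=*/128, /*kBlockM=*/128, /*kBlockN=*/64, /*kNWarps=*/4>;
// giving kNThreads = 128 and a shared-memory requirement Kernel_traits::kSmemSize that the launch
// code would request (e.g. via cudaFuncSetAttribute) when it exceeds the default 48 KB static
// limit. The actual block sizes per head dimension are chosen elsewhere; these numbers are only
// an example.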

////////////////////////////////////////////////////////////////////////////////////////////////////