瀏覽代碼

Add ArchTag to pre/postprocess bwd kernels (#1180)

* Add ArchTag to pre/postprocess bwd kernels

* Type-dependent CC check for bwd pre/postprocess

* Fix CC >= 90 for bwd postprocess

---------

Co-authored-by: Cameron Shinn <cshinn@nvidia.com>
Cameron Shinn 6 月之前
父節點
當前提交
3cea2fb6ee
共有 3 個文件被更改,包括 12 次插入4 次删除
  1. 2 2
      hopper/flash_bwd_launch_template.h
  2. 4 1
      hopper/flash_bwd_postprocess_kernel.h
  3. 6 1
      hopper/flash_bwd_preprocess_kernel.h

+ 2 - 2
hopper/flash_bwd_launch_template.h

@@ -25,7 +25,7 @@ template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_caus
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
     using ElementAccum = float;
-    using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, /*Clear_dQaccum=*/true, Varlen>;
+    using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
     int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * 128, 128);
     typename PreprocessKernel::Arguments preprocess_args {
         static_cast<Element const*>(params.o_ptr),
@@ -130,7 +130,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
     CHECK_CUDA_KERNEL_LAUNCH();
 
-    using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum,
+    using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90,
         AttnKernel::CollectiveMainloop::kNThreadsdQ,
         typename AttnKernel::CollectiveMainloop::SmemLayoutdQaccumTMA,
         typename AttnKernel::CollectiveMainloop::TiledMmadQ,

+ 4 - 1
hopper/flash_bwd_postprocess_kernel.h

@@ -18,7 +18,7 @@ namespace flash {
 
 using namespace cute;
 
-template <class TileShape_MK_, class Element, class ElementAccum, int kNThreads, class SmemLayoutdQaccumTMA,
+template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class SmemLayoutdQaccumTMA,
           class TiledMma, bool dQ_swapAB>
 class FlashAttnBwdPostprocessConvertdQ {
 
@@ -26,6 +26,9 @@ public:
 
     // Type Aliases
     using TileShape_MK = TileShape_MK_;
+    using ArchTag = ArchTag_;
+
+    static_assert(ArchTag::kMinComputeCapability >= 90);
 
     static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
     static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

+ 6 - 1
hopper/flash_bwd_preprocess_kernel.h

@@ -17,13 +17,18 @@ namespace flash {
 
 using namespace cute;
 
-template <class TileShape_MK_, class Element, class ElementAccum, bool Clear_dQaccum, bool Varlen>
+template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
 class FlashAttnBwdPreprocess {
 
 public:
 
     // Type Aliases
     using TileShape_MK = TileShape_MK_;
+    using ArchTag = ArchTag_;
+
+    static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
+                  std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
+                  std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
 
     static constexpr uint32_t MaxThreadsPerBlock = 256;
     static constexpr uint32_t MinBlocksPerMultiprocessor = 2;