Browse Source

Write zero to out / grad if seqlen_q or seqlen_k is zero

Tri Dao 1 year ago
parent
commit
db2f80692c

+ 16 - 3
csrc/flash_attn/flash_api.cpp

@@ -405,8 +405,14 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    run_mha_fwd(params, stream);
+    if (seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentCUDAStream().stream();
+        run_mha_fwd(params, stream);
+    } else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
 
     at::Tensor out_padded = out;
     if (head_size_og % 8 != 0) {
@@ -794,7 +800,14 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         params.rng_state[1] = std::get<1>(seeds);
     }
 
-    launch(params, stream, /*configure=*/false);
+    if (seqlen_q > 0) {
+        launch(params, stream, /*configure=*/false);
+    } else {
+        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+        dk.zero_();
+        dv.zero_();
+        softmax_d.zero_();
+    }
 
     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {

+ 3 - 2
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -444,7 +444,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
-    if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
+    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
 
     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
     if (Is_local) {
@@ -672,7 +672,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // We might need to exit early and write 0 to dK and dV for those blocks.
     // Otherwise we get wrong result for the case where we don't enter the for loop.
     // And we might read OOB elements from gQ and gdO.
-    if (Is_local && m_block < m_block_min) {
+    // This also covers the case where actual_seqlen_q == 0
+    if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
         const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
           + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
         const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)

+ 43 - 43
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -91,7 +91,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
-    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 
     const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
@@ -101,50 +101,50 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
-        // We exit early and write 0 to gO and gLSE.
-        // Otherwise we might read OOB elements from gK and gV.
-        if (n_block_max <= n_block_min) {
-            // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
-            // exit early and no one saves the rng state.
-            if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-                auto seeds = at::cuda::philox::unpack(params.philox_args);
-                params.rng_state[0] = std::get<0>(seeds);
-                params.rng_state[1] = std::get<1>(seeds);
-            }
-            const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-                + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-            const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-            Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                    make_stride(params.o_row_stride, _1{}));
-            Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                                      Shape<Int<kBlockM>>{}, Stride<_1>{});
-
-            typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
-            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
-            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
-            Tensor tOrO = make_tensor<Element>(shape(tOgO));
-            clear(tOrO);
-            // Construct identity layout for sO
-            Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
-            // Repeat the partitioning with identity layouts
-            Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
-            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
-            if (!Is_even_K) {
-                #pragma unroll
-                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
-            }
-            // Clear_OOB_K must be false since we don't want to write zeros to gmem
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
-            );
+    }
+    // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+    // Otherwise we might read OOB elements from gK and gV.
+    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+        // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+        // exit early and no one saves the rng state.
+        if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+            auto seeds = at::cuda::philox::unpack(params.philox_args);
+            params.rng_state[0] = std::get<0>(seeds);
+            params.rng_state[1] = std::get<1>(seeds);
+        }
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+        const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                  Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+        Tensor tOrO = make_tensor<Element>(shape(tOgO));
+        clear(tOrO);
+        // Construct identity layout for sO
+        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        // Repeat the partitioning with identity layouts
+        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+        if (!Is_even_K) {
             #pragma unroll
-            for (int m = 0; m < size<1>(tOgO); ++m) {
-                const int row = get<0>(tOcO(0, m, 0));
-                if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
-            }
-            return;
+            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
         }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        );
+        #pragma unroll
+        for (int m = 0; m < size<1>(tOgO); ++m) {
+            const int row = get<0>(tOcO(0, m, 0));
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+        }
+        return;
     }
     // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }