فهرست منبع

Implement cache_leftpad

Tri Dao 8 ماه پیش
والد
کامیت
40e534a7f6

+ 21 - 0
csrc/flash_attn/flash_api.cpp

@@ -532,6 +532,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
@@ -731,6 +732,16 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                            head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
     }
 
+    if (leftpad_k_.has_value()) {
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
+
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -1279,6 +1290,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                 c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                 c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
@@ -1469,6 +1481,15 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+    if (leftpad_k_.has_value()) {
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
 
     if (rotary_cos_.has_value()) {
         TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");

+ 5 - 3
csrc/flash_attn/src/block_info.h

@@ -18,8 +18,9 @@ struct BlockInfo {
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
 
@@ -30,13 +31,14 @@ struct BlockInfo {
 
     template <typename index_t>
     __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }
 
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };

+ 1 - 0
csrc/flash_attn/src/flash.h

@@ -76,6 +76,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
 
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;

+ 6 - 4
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -690,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
@@ -711,9 +711,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
 
-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -791,7 +793,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),

+ 7 - 2
flash_attn/flash_attn_interface.py

@@ -81,7 +81,8 @@ def _flash_attn_varlen_forward(
     softcap,
     alibi_slopes,
     return_softmax,
-    block_table,
+    block_table=None,
+    leftpad_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,6 +94,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        leftpad_k,
         block_table,
         alibi_slopes,
         max_seqlen_q,
@@ -1150,6 +1152,7 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
     block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
@@ -1217,11 +1220,12 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
-        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
                  might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
+        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -1269,6 +1273,7 @@ def flash_attn_with_kvcache(
         rotary_cos,
         rotary_sin,
         cache_batch_idx,
+        cache_leftpad,
         block_table,
         alibi_slopes,
         None,

+ 30 - 3
tests/test_flash_attn.py

@@ -182,9 +182,14 @@ def construct_local_mask(
     query_padding_mask=None,
     key_padding_mask=None,
     device=None,
+    key_leftpad=None,
 ):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
     sk = (
         seqlen_k
         if key_padding_mask is None
@@ -219,6 +224,7 @@ def attention_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     """
     Arguments:
@@ -268,6 +274,7 @@ def attention_ref(
             query_padding_mask,
             key_padding_mask,
             q.device,
+            key_leftpad=key_leftpad,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
@@ -306,6 +313,7 @@ def attention_kvpacked_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     return attention_ref(
         q,
@@ -321,6 +329,7 @@ def attention_kvpacked_ref(
         window_size=window_size,
         softcap=softcap,
         reorder_ops=reorder_ops,
+        key_leftpad=key_leftpad,
     )
 
 
@@ -1868,9 +1877,11 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("paged_kv_block_size", [None])
+@pytest.mark.parametrize("has_leftpad", [False, True])
+# @pytest.mark.parametrize("has_leftpad", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1898,6 +1909,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    has_leftpad,
     paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
@@ -1916,6 +1928,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if has_batch_idx and paged_kv_block_size is not None:
         pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1961,9 +1975,19 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
     if has_batch_idx:
         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
             :batch_size
@@ -2038,6 +2062,7 @@ def test_flash_attn_kvcache(
         rotary_sin=sin,
         cache_seqlens=cache_seqlens,
         cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
         block_table=block_table,
         causal=causal,
         window_size=window_size,
@@ -2066,6 +2091,7 @@ def test_flash_attn_kvcache(
         None,
         causal=causal,
         window_size=window_size,
+        key_leftpad=cache_leftpad,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -2080,6 +2106,7 @@ def test_flash_attn_kvcache(
         window_size=window_size,
         upcast=False,
         reorder_ops=True,
+        key_leftpad=cache_leftpad,
     )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")