
Add custom ops for compatibility with PT Compile (#1139)

* Add custom ops for compatibility with PT Compile

* Add support for varlen functions too

* Add version checks for pytorch API

* Fix PT compile interfaces so it works e2e

* Make sure PT < 2.4 runs fine

* Fix python mistake

* Fix all the autograd magic issues

* typo on head_dim

* Fix deterministic test failures, remove unneeded detaches()

* remove test requires_grad

* Resolve all the pytorch versioning issues

* C++ and python refactor to improve padding management for torch.compile()

* Add improvements suggested by @anijain2305
Antoni Viros committed 6 months ago (parent commit 83e41b3ca4)
2 changed files with 450 additions and 228 deletions:
  1. csrc/flash_attn/flash_api.cpp (+50, -111)
  2. flash_attn/flash_attn_interface.py (+400, -117)
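
For context, a minimal sketch of the end-to-end usage this commit enables: with the forward/backward ops registered through torch.library, flash_attn_func no longer graph-breaks under torch.compile(). This is not part of the diff and assumes PyTorch >= 2.4, a CUDA build of flash-attn, and illustrative shapes:

    import torch
    from flash_attn import flash_attn_func

    # On PyTorch >= 2.4 the interface dispatches through registered custom ops,
    # so torch.compile() can trace it instead of falling back to eager.
    compiled_attn = torch.compile(flash_attn_func)

    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    out = compiled_attn(q, k, v, causal=True)
    out.sum().backward()  # gradients flow through the registered backward op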

csrc/flash_attn/flash_api.cpp (+50, -111)

@@ -343,10 +343,10 @@ void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi
 }
 
 std::vector<at::Tensor>
-mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
-        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
-        const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
-        c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
+mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
         c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
@@ -385,11 +385,12 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
     int num_heads = sizes[2];
-    const int head_size_og = sizes[3];
+    const int head_size = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
@@ -403,28 +404,17 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
         seqlen_q = ngroups;
         num_heads = num_heads_k;
     }
 
-    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
-
-    at::Tensor q_padded, k_padded, v_padded;
-    if (head_size_og % 8 != 0) {
-        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        q_padded = q;
-        k_padded = k;
-        v_padded = v;
-    }
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
 
     at::Tensor out;
     if (out_.has_value()) {
@@ -432,17 +422,15 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
         if (seqlenq_ngroups_swapped) {
-            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
         }
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
-        out = torch::empty_like(q_padded);
+        out = torch::empty_like(q);
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size = round_multiple(head_size_og, 8);
     const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
@@ -460,6 +448,9 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
     }
+    else {
+        p = torch::empty({ 0 }, opts);
+    }
 
     Flash_fwd_params params;
     set_params_fprop(params,
@@ -468,7 +459,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      seqlen_q_rounded, seqlen_k_rounded,
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
-                     q_padded, k_padded, v_padded, out,
+                     q, k, v, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
                      /*seqused_k=*/nullptr,
@@ -515,19 +506,12 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
 
-    at::Tensor out_padded = out;
-    if (head_size_og % 8 != 0) {
-        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        if (out_.has_value()) { out_.value().copy_(out); }
-    }
-
     if (seqlenq_ngroups_swapped) {
-        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
+        q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p, rng_state};
 }
 
 std::vector<at::Tensor>
@@ -595,7 +579,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
 
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
-    const int head_size_og = sizes[2];
+    const int head_size = sizes[2];
     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
 
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
@@ -612,10 +596,10 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
         max_seqlen_q = ngroups;
         num_heads = num_heads_k;
         cu_seqlens_q_d = nullptr;
@@ -624,20 +608,21 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     const int total_q = q.sizes()[0];
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
 
-    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
+    CHECK_SHAPE(q, total_q, num_heads, head_size);
     if (!paged_KV) {
         const int total_k = k.size(0);
-        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
     }
 
@@ -651,34 +636,21 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         CHECK_SHAPE(seqused_k_, batch_size);
     }
 
-    at::Tensor q_padded, k_padded, v_padded;
-    if (head_size_og % 8 != 0) {
-        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        q_padded = q;
-        k_padded = k;
-        v_padded = v;
-    }
-
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+        CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
         if (seqlenq_ngroups_swapped) {
-            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
         }
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
-        out = torch::empty_like(q_padded);
+        out = torch::empty_like(q);
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size = round_multiple(head_size_og, 8);
     const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
@@ -695,6 +667,9 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
     }
+    else {
+        p = torch::empty({ 0 }, opts);
+    }
 
     if (zero_tensors) {
         out.zero_();
@@ -709,7 +684,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                      seqlen_q_rounded, seqlen_k_rounded,
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
-                     q_padded, k_padded, v_padded, out,
+                     q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
@@ -727,8 +702,8 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
         params.block_table_batch_stride = block_table.stride(0);
-        params.k_batch_stride = k_padded.stride(0);
-        params.v_batch_stride = v_padded.stride(0);
+        params.k_batch_stride = k.stride(0);
+        params.v_batch_stride = v.stride(0);
     }
     params.page_block_size = page_block_size;
     // Keep references to these tensors to extend their lifetime
@@ -779,22 +754,15 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
 
-    at::Tensor out_padded = out;
-    if (head_size_og % 8 != 0) {
-        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        if (out_.has_value()) { out_.value().copy_(out); }
-    }
-
     if (seqlenq_ngroups_swapped) {
-        int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
-        int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
+        int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
+        int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
-        out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
-        q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }
 
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p, rng_state};
 }
 
 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
@@ -808,7 +776,7 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 }
 
 std::vector<at::Tensor>
-mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_size_og
+mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
         const at::Tensor &q,   // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size
@@ -869,7 +837,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     const int batch_size = sizes[0];
     const int seqlen_q = sizes[1];
     const int num_heads = sizes[2];
-    const int head_size_og = dout.size(3);
     const int head_size = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
@@ -886,7 +853,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
-    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
@@ -896,7 +862,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
-    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
@@ -927,13 +893,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         dv = torch::empty_like(v);
     }
 
-    at::Tensor dout_padded;
-    if (head_size_og % 8 != 0) {
-        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        dout_padded = dout;
-    }
-
     // bool loop = seqlen_k > blocksize_c;
     // TODO: change later, for now set to true for simplicity
     bool loop = true;
@@ -975,7 +934,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      q, k, v, out,
-                     dout_padded, dq, dk_expanded, dv_expanded,
+                     dout, dq, dk_expanded, dv_expanded,
                      nullptr,
                      nullptr,
                      loop ? dq_accum.data_ptr() : nullptr,
@@ -1029,11 +988,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
         at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
     }
-    if (head_size_og % 8 != 0) {
-        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-    }
 
     return { dq, dk, dv, softmax_d };
 }
@@ -1110,7 +1064,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     const int total_q = sizes[0];
     const int batch_size = cu_seqlens_q.numel() - 1;
     const int num_heads = sizes[1];
-    const int head_size_og = dout.size(2);
     const int head_size = sizes[2];
     const int total_k = k.size(0);
     const int num_heads_k = k.size(1);
@@ -1128,8 +1081,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
-    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
-
     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
 
@@ -1137,7 +1088,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     CHECK_SHAPE(k, total_k, num_heads_k, head_size);
     CHECK_SHAPE(v, total_k, num_heads_k, head_size);
     CHECK_SHAPE(out, total_q, num_heads, head_size);
-    CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
+    CHECK_SHAPE(dout, total_q, num_heads, head_size);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
 
@@ -1170,13 +1121,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         dv = torch::empty_like(v);
     }
 
-    at::Tensor dout_padded;
-    if (head_size_og % 8 != 0) {
-        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        dout_padded = dout;
-    }
-
     // bool loop = max_seqlen_k > blocksize_c;
     // TODO: change later, for now set to true for simplicity
     bool loop = true;
@@ -1231,7 +1175,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      q, k, v, out,
-                     dout_padded, dq, dk_expanded, dv_expanded,
+                     dout, dq, dk_expanded, dv_expanded,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
                      loop ? dq_accum.data_ptr() : nullptr,
@@ -1284,11 +1228,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
         at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
     }
-    if (head_size_og % 8 != 0) {
-        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-    }
 
     return { dq, dk, dv, softmax_d };
 }
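
Note on the C++ changes above: the entry points now require head_size % 8 == 0 and return only {out, softmax_lse, p, rng_state}; the pad-to-multiple-of-8 and slice-back logic moved to the Python layer so torch.compile() can trace it. A minimal sketch of that pattern, mirroring the autograd Functions in the second file (helper names are illustrative, not part of the diff):

    import torch
    import torch.nn.functional as F

    def pad_head_dim(x: torch.Tensor) -> torch.Tensor:
        # Zero-pad the last (head) dimension up to the next multiple of 8,
        # as the C++ entry points previously did internally.
        rem = x.size(-1) % 8
        return F.pad(x, [0, 8 - rem]) if rem != 0 else x

    def unpad_head_dim(out_padded: torch.Tensor, head_size_og: int) -> torch.Tensor:
        # Slice a padded result back to the caller's original head dimension.
        return out_padded[..., :head_size_og]

    q = torch.randn(2, 128, 4, 60)      # head_size 60 is not a multiple of 8
    q_padded = pad_head_dim(q)          # shape (2, 128, 4, 64)
    out = unpad_head_dim(q_padded, 60)  # shape (2, 128, 4, 60)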

flash_attn/flash_attn_interface.py (+400, -117)

@@ -1,6 +1,6 @@
 # Copyright (c) 2023, Tri Dao.
 
-from typing import Optional, Union
+from typing import Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -14,6 +14,7 @@ import flash_attn_2_cuda as flash_attn_cuda
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
+
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256
@@ -45,11 +46,49 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
         return 64
 
 
+def round_multiple(x, m):
+    return (x + m - 1) // m * m
+
+
+# torch.compile() support is only enabled for pytorch >= 2.4
+# The reason for this is that we are using the new custom_op and register_fake
+# APIs, which support inplace modification of inputs in the function itself
+if torch.__version__ >= "2.4.0":
+    _torch_custom_op_wrapper = torch.library.custom_op
+    _torch_register_fake_wrapper = torch.library.register_fake
+else:
+    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+        if fn is None:
+            return wrap
+        return fn
+    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+        if fn is None:
+            return wrap
+        return fn
+    _torch_custom_op_wrapper = noop_custom_op_wrapper
+    _torch_register_fake_wrapper = noop_register_fake_wrapper
+
+
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
-):
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    return_softmax: bool
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
+    out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
         k,
         v,
@@ -58,36 +97,71 @@ def _flash_attn_forward(
         dropout_p,
         softmax_scale,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         return_softmax,
         None,
     )
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+    return out, softmax_lse, S_dmask, rng_state
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
+def _flash_attn_forward_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    return_softmax: bool
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    batch_size, seqlen_q, num_heads, head_size = q.shape
+    seqlen_k = k.shape[1]
+    out = torch.empty_like(q)
+    softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
+    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
+    if return_softmax:
+        p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
+    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
+
+    return out, softmax_lse, p, rng_state
+
+
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
+else:
+    _wrapped_flash_attn_forward = _flash_attn_forward
 
 
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
 def _flash_attn_varlen_forward(
-    q,
-    k,
-    v,
-    cu_seqlens_q,
-    cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    dropout_p,
-    softmax_scale,
-    causal,
-    window_size=(-1, -1),
-    softcap=0.0,
-    alibi_slopes=None,
-    return_softmax=False,
-    block_table=None,
-    leftpad_k=None,
-    seqused_k=None,
-):
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    return_softmax: bool = False,
+    block_table: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
+    out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
         k,
         v,
@@ -104,36 +178,81 @@ def _flash_attn_varlen_forward(
         softmax_scale,
         False,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         return_softmax,
         None,
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
-
-
+    return out, softmax_lse, S_dmask, rng_state
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
+def _flash_attn_varlen_forward_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    return_softmax: bool = False,
+    block_table: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    paged_kv = block_table is not None
+    batch_size = cu_seqlens_q.numel() - 1
+    total_q, num_heads, _ = q.shape
+    
+    out = torch.empty_like(q)
+    softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
+    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
+    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
+    seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
+    if return_softmax:
+        p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
+    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
+    return out, softmax_lse, p, rng_state
+
+
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
+else:
+    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
+
+
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
 def _flash_attn_backward(
-    dout,
-    q,
-    k,
-    v,
-    out,
-    softmax_lse,
-    dq,
-    dk,
-    dv,
-    dropout_p,
-    softmax_scale,
-    causal,
-    window_size,
-    softcap,
-    alibi_slopes,
-    deterministic,
-    rng_state=None,
-):
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
@@ -155,39 +274,81 @@ def _flash_attn_backward(
         dropout_p,
         softmax_scale,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         deterministic,
         None,
         rng_state,
     )
-    return dq, dk, dv, softmax_d
+    return softmax_d
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
+def _flash_attn_backward_fake(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    if dq is None:
+        dq = torch.empty_like(q)
+    if dk is None:
+        dk = torch.empty_like(k)
+    if dv is None:
+        dv = torch.empty_like(v)
+    batch_size, seqlen_q, num_heads, _ = q.shape
+    softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
+    
+    return softmax_d
+
 
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
+else:
+    _wrapped_flash_attn_backward = _flash_attn_backward
 
+
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
 def _flash_attn_varlen_backward(
-    dout,
-    q,
-    k,
-    v,
-    out,
-    softmax_lse,
-    dq,
-    dk,
-    dv,
-    cu_seqlens_q,
-    cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    dropout_p,
-    softmax_scale,
-    causal,
-    window_size,
-    softcap,
-    alibi_slopes,
-    deterministic,
-    rng_state=None,
-):
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
@@ -214,8 +375,8 @@ def _flash_attn_varlen_backward(
         softmax_scale,
         False,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         deterministic,
         None,
@@ -223,7 +384,53 @@ def _flash_attn_varlen_backward(
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
-    return dq, dk, dv, softmax_d
+    return softmax_d
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
+def _flash_attn_varlen_backward_fake(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    batch_size = cu_seqlens_q.numel() - 1
+    total_q, num_heads, _ = q.shape
+
+    if dq is None:
+        dq = torch.empty_like(q)
+    if dk is None:
+        dk = torch.empty_like(k)
+    if dv is None:
+        dv = torch.empty_like(v)
+    softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
+    
+    return softmax_d
+
+
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
+else:
+    _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward
 
 
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
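
At this point all four ops, their _fake counterparts, and the version-gated wrappers are registered. Not part of the diff, but a quick way to exercise the registrations: torch.library.opcheck (PyTorch >= 2.4) validates the real CUDA kernel against the registered fake, which catches shape or dtype mismatches in the _fake functions above. A sketch with illustrative shapes, assuming a CUDA device:

    import torch
    import flash_attn  # importing the package registers the flash_attn:: ops

    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

    # Positional args follow the schema above: q, k, v, dropout_p, softmax_scale,
    # causal, window_size_left, window_size_right, softcap, alibi_slopes, return_softmax.
    torch.library.opcheck(
        torch.ops.flash_attn._flash_attn_forward.default,
        args=(q, k, v, 0.0, q.shape[-1] ** -0.5, False, -1, -1, 0.0, None, False),
    )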
@@ -242,14 +449,21 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
-            qkv[:, :, 0],
-            qkv[:, :, 1],
-            qkv[:, :, 2],
+        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
+        head_size_og = q.size(3)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state =  _wrapped_flash_attn_forward(
+            q,
+            k,
+            v,
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -262,6 +476,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -269,8 +484,12 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-        _flash_attn_backward(
-            dout,
+        head_size_og = dout.size(3)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -282,7 +501,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -310,10 +530,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
-            qkv[:, 0],
-            qkv[:, 1],
-            qkv[:, 2],
+        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
+        head_size_og = q.size(2)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
+            q,
+            k,
+            v,
             cu_seqlens,
             cu_seqlens,
             max_seqlen,
@@ -321,7 +547,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -336,6 +563,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -343,8 +571,12 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-        _flash_attn_varlen_backward(
-            dout,
+        head_size_og = dout.size(2)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_varlen_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -360,7 +592,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -387,14 +620,21 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
+        k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
+        head_size_og = q.size(3)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
             q,
-            kv[:, :, 0],
-            kv[:, :, 1],
+            k,
+            v,
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -407,6 +647,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -415,8 +656,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
-        _flash_attn_backward(
-            dout,
+        head_size_og = dout.size(3)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -428,7 +673,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -460,10 +706,16 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+        k, v = kv[:, 0].detach(), kv[:, 1].detach()
+        head_size_og = q.size(2)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
             q,
-            kv[:, 0],
-            kv[:, 1],
+            k,
+            v,
             cu_seqlens_q,
             cu_seqlens_k,
             max_seqlen_q,
@@ -471,7 +723,8 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -489,6 +742,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -497,8 +751,12 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
-        _flash_attn_varlen_backward(
-            dout,
+        head_size_og = dout.size(2)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_varlen_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -514,7 +772,8 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -543,14 +802,20 @@ class FlashAttnFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
+        head_size_og = q.size(3)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
             q,
             k,
             v,
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -563,14 +828,19 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-        _flash_attn_backward(
-            dout,
+        head_size_og = dout.size(3)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -582,7 +852,8 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -617,7 +888,12 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+        head_size_og = q.size(2)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
             q,
             k,
             v,
@@ -628,7 +904,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -646,14 +923,19 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-        _flash_attn_varlen_backward(
-            dout,
+        head_size_og = dout.size(2)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_varlen_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -669,7 +951,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,