瀏覽代碼

Hotfix due to change of upstream api (#1239)

rocking 6 月之前
父節點
當前提交
53a4f34163

+ 12 - 12
csrc/flash_attn_ck/flash_api.cpp

@@ -5,11 +5,11 @@
 #include "flash_common.hpp"
 
 std::vector<at::Tensor>
-mha_fwd(at::Tensor &q,
-        const at::Tensor &k,
-        const at::Tensor &v,
-        c10::optional<at::Tensor> &out_,
-        c10::optional<at::Tensor> &alibi_slopes_,
+mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
@@ -43,7 +43,7 @@ mha_varlen_fwd(at::Tensor &q,                               // total_q x num_hea
                c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor>
-mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x head_size_og
+mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
         const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
@@ -113,10 +113,10 @@ mha_fwd_kvcache(at::Tensor &q,                                     // batch_size
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
-    m.doc() = "FlashAttention";
-    m.def("fwd", &mha_fwd, "Forward pass");
-    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
-    m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
-    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
+        m.doc() = "FlashAttention";
+        m.def("fwd", &mha_fwd, "Forward pass");
+        m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
+        m.def("bwd", &mha_bwd, "Backward pass");
+        m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+        m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }

+ 22 - 38
csrc/flash_attn_ck/mha_bwd.cpp

@@ -195,7 +195,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
 }
 
 std::vector<at::Tensor>
-mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x head_size_og
+mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x multiple_of(head_size, 8)
         const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
@@ -248,18 +248,14 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
     const int batch_size = sizes[0];
     const int seqlen_q = sizes[1];
     const int num_heads = sizes[2];
-    const int head_size_og = dout.size(3);  // unpadded hdim
-    const int head_size_8x = sizes[3];
+    const int head_size = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
+    TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
-
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
 
@@ -279,11 +275,11 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
 
     // q, k, v, out had been padded in mha_fwd
     // dq_, dk_, dv_ are also padded tensor
-    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
-    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
-    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
-    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
-    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
@@ -291,7 +287,7 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
         CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
-        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
+        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
     } else {
         dq = torch::empty_like(q);
     }
@@ -300,7 +296,7 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
     TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
     CHECK_DEVICE(dk);
     TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
-    CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
+    CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
         dk = torch::empty_like(k);
     }
@@ -309,18 +305,11 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
         CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
-        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
+        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
         dv = torch::empty_like(v);
     }
 
-    at::Tensor dout_padded;
-    if (head_size_og % 8 != 0) {
-        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        dout_padded = dout;
-    }
-
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
@@ -329,17 +318,17 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
     at::Tensor dq_accum;
 
     if (!deterministic) {
-        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
     } else {
-        const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
+        const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
     if (num_heads_k != num_heads) {  // MQA / GQA
-        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
-        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
+        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
+        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
     } else {
         dk_expanded = dk;
         dv_expanded = dv;
@@ -366,7 +355,7 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
-            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
+            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
 
         auto args =
             get_ck_fmha_bwd_args(
@@ -376,14 +365,14 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
                 seqlen_k,
                 num_heads,
                 num_heads_k,
-                head_size_8x,
+                head_size,
                 q,
                 k,
                 v,
                 alibi_slopes_,
                 out,
                 softmax_lse,
-                dout_padded,
+                dout,
                 dq_accum,
                 softmax_d,
                 dq,
@@ -405,13 +394,8 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
 
     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {
-        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
-        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
-    }
-    if (head_size_og % 8 != 0) {
-        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
+        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
     }
 
     return { dq, dk, dv, softmax_d };

+ 26 - 45
csrc/flash_attn_ck/mha_fwd.cpp

@@ -141,10 +141,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
 }
 
 std::vector<at::Tensor>
-mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x head_size
-        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
-        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
-        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
+mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
         c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
         const float p_dropout,
         const float softmax_scale,
@@ -175,11 +175,12 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
     int num_heads = sizes[2];
-    const int head_size_og = sizes[3];
+    const int head_size = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
@@ -206,29 +207,17 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
         seqlen_q = ngroups;
         num_heads = num_heads_k;
     }
 
-    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
-
-    at::Tensor q_padded, k_padded, v_padded;
-    if (head_size_og % 8 != 0) {
-        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    }
-    else {
-        q_padded = q;
-        k_padded = k;
-        v_padded = v;
-    }
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
 
     at::Tensor out;
     if (out_.has_value()) {
@@ -236,19 +225,15 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
         if (seqlenq_ngroups_swapped) {
-            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
         }
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     }
     else {
-        out = torch::empty_like(q_padded);
+        out = torch::empty_like(q);
     }
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_8x = round_multiple(head_size_og, 8);
-
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -266,6 +251,9 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
         TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
         p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
     }
+    else {
+        p = torch::empty({ 0 }, opts);
+    }
 
     uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
@@ -292,7 +280,7 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
             get_ck_fmha_fwd_traits(
                 mask,
                 q_dtype_str,
-                head_size_8x,
+                head_size,
                 has_dropout,
                 has_lse,
                 alibi_slopes_.has_value());
@@ -307,10 +295,10 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
                 seqlen_k,
                 num_heads,
                 num_heads_k,
-                head_size_8x,
-                q_padded,
-                k_padded,
-                v_padded,
+                head_size,
+                q,
+                k,
+                v,
                 alibi_slopes_,
                 out,
                 softmax_lse,
@@ -329,17 +317,10 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
 
-    at::Tensor out_padded = out;
-    if (head_size_og % 8 != 0) {
-        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        if (out_.has_value()) { out_.value().copy_(out); }
-    }
-
     if (seqlenq_ngroups_swapped) {
-        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
-        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
+        q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p, rng_state};
 }

+ 21 - 37
csrc/flash_attn_ck/mha_varlen_bwd.cpp

@@ -264,18 +264,14 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
     const int total_q = sizes[0];
     const int batch_size = cu_seqlens_q.numel() - 1;
     const int num_heads = sizes[1];
-    const int head_size_og = dout.size(2);
-    const int head_size_8x = sizes[2];
+    const int head_size = sizes[2];
     const int total_k = k.size(0);
     const int num_heads_k = k.size(1);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
+    TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
-
     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
 
@@ -295,11 +291,11 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
 
     // q, k, v, out had been padded in mha_fwd
     // dq_, dk_, dv_ are also padded tensor
-    CHECK_SHAPE(q, total_q, num_heads, head_size_8x);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x);
-    CHECK_SHAPE(out, total_q, num_heads, head_size_8x);
-    CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
+    CHECK_SHAPE(q, total_q, num_heads, head_size);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
+    CHECK_SHAPE(out, total_q, num_heads, head_size);
+    CHECK_SHAPE(dout, total_q, num_heads, head_size);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
 
@@ -309,7 +305,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
         CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
-        CHECK_SHAPE(dq, total_q, num_heads, head_size_8x);
+        CHECK_SHAPE(dq, total_q, num_heads, head_size);
     } else {
         dq = torch::empty_like(q);
     }
@@ -318,7 +314,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
         CHECK_DEVICE(dk);
         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
-        CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x);
+        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
     } else {
         dk = torch::empty_like(k);
     }
@@ -327,18 +323,11 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
         CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
-        CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x);
+        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
     } else {
         dv = torch::empty_like(v);
     }
 
-    at::Tensor dout_padded;
-    if (head_size_og % 8 != 0) {
-        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        dout_padded = dout;
-    }
-
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
@@ -347,17 +336,17 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
     at::Tensor dq_accum;
 
     if (!deterministic) {
-        dq_accum = torch::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({1, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
     } else {
-        const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
+        const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
     if (num_heads_k != num_heads) {  // MQA / GQA
-        dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
-        dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
+        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
+        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
     } else {
         dk_expanded = dk;
         dv_expanded = dv;
@@ -391,7 +380,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
-            get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
+            get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
 
         auto args =
             get_ck_fmha_varlen_bwd_args(
@@ -401,7 +390,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
                 max_seqlen_k,
                 num_heads,
                 num_heads_k,
-                head_size_8x,
+                head_size,
                 q,
                 k,
                 v,
@@ -410,7 +399,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
                 alibi_slopes_,
                 out,
                 softmax_lse,
-                dout_padded,
+                dout,
                 dq_accum,
                 softmax_d,
                 dq,
@@ -432,13 +421,8 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
 
     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {
-        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
-        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
-    }
-    if (head_size_og % 8 != 0) {
-        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
+        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
     }
 
     return { dq, dk, dv, softmax_d };

+ 17 - 36
csrc/flash_attn_ck/mha_varlen_fwd.cpp

@@ -195,7 +195,7 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
 
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
-    const int head_size_og = sizes[2];
+    const int head_size = sizes[2];
     const int num_heads_k = k.size(1);
 
     const int max_num_blocks_per_seq = 0;
@@ -211,7 +211,8 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
     const int total_k = k.size(0);
 
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
-    TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
@@ -234,41 +235,24 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
         mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
     }
 
-    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(q, total_q, num_heads, head_size);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
 
-    at::Tensor q_padded, k_padded, v_padded;
-    if (head_size_og % 8 != 0) {
-        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    }
-    else {
-        q_padded = q;
-        k_padded = k;
-        v_padded = v;
-    }
-
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
-
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+        CHECK_SHAPE(out, total_q, num_heads, head_size);
     }
     else {
-        out = torch::empty_like(q_padded);
+        out = torch::empty_like(q);
     }
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_8x = round_multiple(head_size_og, 8);
-
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -286,6 +270,9 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
         TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
         p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
     }
+    else {
+        p = torch::empty({ 0 }, opts);
+    }
 
     if (zero_tensors)
     {
@@ -319,7 +306,7 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
             get_ck_fmha_varlen_fwd_traits(
                 mask,
                 q_dtype_str,
-                head_size_8x,
+                head_size,
                 has_dropout,
                 has_lse,
                 alibi_slopes_.has_value());
@@ -333,10 +320,10 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
                 max_seqlen_q,
                 num_heads,
                 num_heads_k,
-                head_size_8x,
-                q_padded,
-                k_padded,
-                v_padded,
+                head_size,
+                q,
+                k,
+                v,
                 cu_seqlens_q,
                 cu_seqlens_k,
                 alibi_slopes_,
@@ -357,11 +344,5 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
 
-    at::Tensor out_padded = out;
-    if (head_size_og % 8 != 0) {
-        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        if (out_.has_value()) { out_.value().copy_(out); }
-    }
-
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p, rng_state};
 }