@@ -195,7 +195,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
 }
 
 std::vector<at::Tensor>
-mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
+mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x multiple_of(head_size, 8)
         const at::Tensor &q,    // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,    // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,    // batch_size x seqlen_k x num_heads_k x head_size
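
The new comment on `dout` changes the function's contract: the head dimension of the incoming gradient must already be padded to a multiple of 8, matching the padded `q`, `k`, `v`, and `out` produced by `mha_fwd`. A minimal caller-side sketch of what that contract implies, reusing the same `torch::nn::functional::pad` call that this patch removes from the function body (`pad_head_dim_to_8` is a hypothetical helper name, not part of this patch):

// Hypothetical caller-side helper: ensure dout's last dim is a multiple of 8
// before invoking mha_bwd, mirroring the padding mha_fwd applies to q/k/v/out.
#include <torch/torch.h>

static at::Tensor pad_head_dim_to_8(const at::Tensor &dout) {
    const int64_t head_size_og = dout.size(3);   // unpadded head dim
    const int64_t rem = head_size_og % 8;
    if (rem == 0) {
        return dout;                             // already aligned, no copy
    }
    // Zero-pad the last dimension on the right up to the next multiple of 8.
    return torch::nn::functional::pad(
        dout, torch::nn::functional::PadFuncOptions({0, 8 - rem}));
}
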
@@ -248,18 +248,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     const int batch_size = sizes[0];
     const int seqlen_q = sizes[1];
     const int num_heads = sizes[2];
-    const int head_size_og = dout.size(3); // unpadded hdim
-    const int head_size_8x = sizes[3];
+    const int head_size = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
+    TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
-
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
 
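
The deleted `round_multiple` lambda and its `TORCH_CHECK` enforced consistency between the unpadded `head_size_og` and the padded `head_size_8x`; with a single `head_size` that cross-check is redundant and the `% 8` check above suffices. For reference, a small standalone sketch of the rounding identity the lambda computed (the values are illustrative):

// The removed lambda rounded x up to the nearest multiple of m via
// integer arithmetic: (x + m - 1) / m truncates, then * m scales back up.
constexpr int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

static_assert(round_multiple(100, 8) == 104, "100 rounds up to 104");
static_assert(round_multiple(104, 8) == 104, "multiples of 8 are fixed points");
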
@@ -279,11 +275,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
 
     // q, k, v, out had been padded in mha_fwd
     // dq_, dk_, dv_ are also padded tensor
-    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
-    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
-    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
-    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
-    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
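
All five tensors are now validated against the same padded `head_size`, including `dout`, which previously used the unpadded `head_size_og`. `CHECK_SHAPE` is defined elsewhere in this extension; a sketch of the usual definition in this codebase's style (an assumption, not part of this diff):

// Assumed shape-check macro, in the style used by this extension:
// compare the tensor's sizes against the expected dimension list.
#define CHECK_SHAPE(x, ...) \
    TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
                #x " must have shape (" #__VA_ARGS__ ")")
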
@@ -291,7 +287,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
         CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
-        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
+        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
     } else {
         dq = torch::empty_like(q);
     }
@@ -300,7 +296,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
         CHECK_DEVICE(dk);
         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
-        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
+        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
         dk = torch::empty_like(k);
     }
@@ -309,18 +305,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
         CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
-        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
+        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
         dv = torch::empty_like(v);
     }
 
-    at::Tensor dout_padded;
-    if (head_size_og % 8 != 0) {
-        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-    } else {
-        dout_padded = dout;
-    }
-
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
@@ -329,17 +318,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     at::Tensor dq_accum;
 
     if (!deterministic) {
-        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
     } else {
-        const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
+        const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
     if (num_heads_k != num_heads) {  // MQA / GQA
-        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
-        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
+        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
+        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
     } else {
         dk_expanded = dk;
         dv_expanded = dv;
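
In the deterministic path, `dq_accum` gets one slice per K-tile split rather than a single shared buffer, the idea being that each split can write its partial dQ contribution separately and the results can be combined in a fixed order. The split count is a ceiling division of `seqlen_k` by the tile size `kN0`. A worked example with illustrative sizes (the numbers are not from this patch):

// Illustrative split-count computation for the deterministic path.
const int head_size = 128;
const int seqlen_k  = 4096;
const int kN0       = head_size <= 128 ? 128 : 64;  // K-tile size -> 128
const int nsplits   = (seqlen_k + kN0 - 1) / kN0;   // ceil(4096 / 128) -> 32
// dq_accum then has shape {nsplits, batch_size, seqlen_q, num_heads, head_size}.
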
@@ -366,7 +355,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     ck_tile::stream_config stream_config{stream};
 
     auto traits =
-        get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
+        get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
 
     auto args =
         get_ck_fmha_bwd_args(
@@ -376,14 +365,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
             seqlen_k,
             num_heads,
             num_heads_k,
-            head_size_8x,
+            head_size,
             q,
             k,
             v,
             alibi_slopes_,
             out,
             softmax_lse,
-            dout_padded,
+            dout,
             dq_accum,
             softmax_d,
             dq,
@@ -405,13 +394,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
 
     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {
-        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
-        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
-    }
-    if (head_size_og % 8 != 0) {
-        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
-        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
+        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
     }
 
     return { dq, dk, dv, softmax_d };
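
For MQA/GQA the backward pass produces dK/dV for all `num_heads` query heads, so they are reduced back to `num_heads_k` KV heads by summing each group of `num_heads / num_heads_k` heads; the slicing back to `head_size_og` disappears because nothing is padded inside this function anymore. A standalone sketch of the same group reduction with illustrative shapes (not part of this patch):

// Standalone illustration of the group reduction above: 8 query heads
// sharing 2 KV heads, so each KV head sums a group of 4 query heads.
#include <torch/torch.h>

int main() {
    const int64_t B = 1, S = 16, H = 8, Hk = 2, D = 64;
    auto dk_expanded = torch::randn({B, S, H, D});
    // View the head axis as (kv_head, group) and sum over the group axis,
    // matching at::sum_out(..., at::reshape(...), {3}) in the code above.
    auto dk = dk_expanded.reshape({B, S, Hk, H / Hk, D}).sum(3);
    TORCH_CHECK(dk.sizes() == torch::IntArrayRef({B, S, Hk, D}));
    return 0;
}
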