|
@@ -260,7 +260,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
const float p_dropout,
|
|
|
const float softmax_scale,
|
|
|
bool is_causal,
|
|
|
- const int window_size_left,
|
|
|
+ int window_size_left,
|
|
|
int window_size_right,
|
|
|
const bool return_softmax,
|
|
|
c10::optional<at::Generator> gen_) {
|
|
@@ -300,6 +300,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
|
|
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
|
|
|
|
|
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
|
|
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
|
|
+
|
|
|
// causal=true is the same as causal=false in this case
|
|
|
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
|
|
|
if (is_causal) { window_size_right = 0; }
|
|
@@ -465,7 +468,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
const float softmax_scale,
|
|
|
const bool zero_tensors,
|
|
|
const bool is_causal,
|
|
|
- const int window_size_left,
|
|
|
+ int window_size_left,
|
|
|
int window_size_right,
|
|
|
const bool return_softmax,
|
|
|
c10::optional<at::Generator> gen_) {
|
|
@@ -512,6 +515,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
|
|
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
|
|
|
|
|
+ if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
|
|
|
+ if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
|
|
|
+
|
|
|
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
|
|
|
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
|
|
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
|
|
@@ -675,7 +681,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
const float p_dropout, // probability to drop
|
|
|
const float softmax_scale,
|
|
|
const bool is_causal,
|
|
|
- const int window_size_left,
|
|
|
+ int window_size_left,
|
|
|
int window_size_right,
|
|
|
const bool deterministic,
|
|
|
c10::optional<at::Generator> gen_,
|
|
@@ -738,6 +744,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
|
|
|
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
|
|
|
|
|
|
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
|
|
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
|
|
+
|
|
|
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
|
|
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
|
|
|
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
|
|
@@ -912,7 +921,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
|
|
const float softmax_scale,
|
|
|
const bool zero_tensors,
|
|
|
const bool is_causal,
|
|
|
- const int window_size_left,
|
|
|
+ int window_size_left,
|
|
|
int window_size_right,
|
|
|
const bool deterministic,
|
|
|
c10::optional<at::Generator> gen_,
|
|
@@ -979,6 +988,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
|
|
|
|
|
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
|
|
|
|
|
|
+ if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
|
|
|
+ if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
|
|
|
+
|
|
|
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
|
|
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
|
|
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
|
|
@@ -1160,7 +1172,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
|
|
const float softmax_scale,
|
|
|
bool is_causal,
|
|
|
- const int window_size_left,
|
|
|
+ int window_size_left,
|
|
|
int window_size_right,
|
|
|
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
|
|
int num_splits
|
|
@@ -1216,6 +1228,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
num_heads = num_heads_k;
|
|
|
}
|
|
|
|
|
|
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
|
|
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
|
|
+
|
|
|
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
|
|
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
|
|
|
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
|