|
@@ -42,7 +42,8 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
|
|
float p_dropout,
|
|
|
float softmax_scale,
|
|
|
int window_size_left,
|
|
|
- int window_size_right) {
|
|
|
+ int window_size_right,
|
|
|
+ bool seqlenq_ngroups_swapped=false) {
|
|
|
|
|
|
// Reset the parameters
|
|
|
memset(¶ms, 0, sizeof(params));
|
|
@@ -69,6 +70,10 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
|
|
params.k_batch_stride = k.stride(0);
|
|
|
params.v_batch_stride = v.stride(0);
|
|
|
params.o_batch_stride = out.stride(0);
|
|
|
+ if (seqlenq_ngroups_swapped) {
|
|
|
+ params.q_batch_stride *= seqlen_q;
|
|
|
+ params.o_batch_stride *= seqlen_q;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
|
@@ -251,6 +256,31 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
|
|
|
return 1;
|
|
|
}
|
|
|
|
|
|
+void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size,
|
|
|
+ const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
|
|
|
+ const int head_size_rounded, float p_dropout, const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
|
|
|
+
|
|
|
+ // This needs to match with run_mha_fwd_splitkv_dispatch
|
|
|
+ const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
|
|
|
+ const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
|
|
|
+ // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
|
|
|
+ // In any case we don't expect seqlen_q to be larger than 64 for inference.
|
|
|
+ const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
|
|
|
+ params.num_splits = num_splits;
|
|
|
+ if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
|
|
|
+ if (num_splits < 1) {
|
|
|
+ params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
|
|
|
+ }
|
|
|
+ if (params.num_splits > 1) {
|
|
|
+ at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
|
|
+ at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
|
|
|
+ params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
|
|
+ params.oaccum_ptr = out_accum.data_ptr();
|
|
|
+ }
|
|
|
+ TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
std::vector<at::Tensor>
|
|
|
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
|
@@ -382,23 +412,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
window_size_left,
|
|
|
window_size_right);
|
|
|
|
|
|
- // This needs to match with run_mha_fwd_splitkv_dispatch
|
|
|
- const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
|
|
|
- const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
|
|
|
- // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
|
|
|
- // In any case we don't expect seqlen_q to be larger than 64 for inference.
|
|
|
- const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
|
|
|
- params.num_splits = 1;
|
|
|
- if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
|
|
|
- params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
|
|
|
- if (params.num_splits > 1) {
|
|
|
- at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
|
|
- at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
|
|
|
- params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
|
|
- params.oaccum_ptr = out_accum.data_ptr();
|
|
|
- }
|
|
|
- TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
|
|
|
- }
|
|
|
+
|
|
|
+ set_params_splitkv(params, batch_size, num_heads,
|
|
|
+ head_size, seqlen_k, seqlen_q,
|
|
|
+ head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
|
|
|
|
|
|
// number of times random will be generated per thread, to offset philox counter in thc random
|
|
|
// state
|
|
@@ -454,7 +471,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
}
|
|
|
|
|
|
std::vector<at::Tensor>
|
|
|
-mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
|
|
+mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
|
|
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
|
|
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
|
|
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
|
@@ -462,18 +479,17 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
const at::Tensor &cu_seqlens_k, // b+1
|
|
|
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
|
|
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
|
|
- const int max_seqlen_q,
|
|
|
+ int max_seqlen_q,
|
|
|
const int max_seqlen_k,
|
|
|
const float p_dropout,
|
|
|
const float softmax_scale,
|
|
|
const bool zero_tensors,
|
|
|
- const bool is_causal,
|
|
|
+ bool is_causal,
|
|
|
int window_size_left,
|
|
|
int window_size_right,
|
|
|
const bool return_softmax,
|
|
|
c10::optional<at::Generator> gen_) {
|
|
|
|
|
|
- if (is_causal) { window_size_right = 0; }
|
|
|
auto dprops = at::cuda::getCurrentDeviceProperties();
|
|
|
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
|
|
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
|
@@ -505,12 +521,30 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
|
|
|
const auto sizes = q.sizes();
|
|
|
|
|
|
- const int total_q = sizes[0];
|
|
|
const int batch_size = cu_seqlens_q.numel() - 1;
|
|
|
- const int num_heads = sizes[1];
|
|
|
+ int num_heads = sizes[1];
|
|
|
const int head_size_og = sizes[2];
|
|
|
const int total_k = k.size(0);
|
|
|
const int num_heads_k = k.size(1);
|
|
|
+
|
|
|
+ if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
|
|
|
+ if (is_causal) { window_size_right = 0; }
|
|
|
+
|
|
|
+ void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
|
|
|
+
|
|
|
+ // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
|
|
+ // H/t Daniel Haziza
|
|
|
+ const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
|
|
+ if (seqlenq_ngroups_swapped) {
|
|
|
+ const int ngroups = num_heads / num_heads_k;
|
|
|
+ q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
|
|
+ max_seqlen_q = ngroups;
|
|
|
+ num_heads = num_heads_k;
|
|
|
+ cu_seqlens_q_d = nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int total_q = q.sizes()[0];
|
|
|
+
|
|
|
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
|
|
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
|
|
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
|
@@ -588,7 +622,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
num_heads, num_heads_k,
|
|
|
head_size, head_size_rounded,
|
|
|
q_padded, k_padded, v_padded, out,
|
|
|
- cu_seqlens_q.data_ptr(),
|
|
|
+ cu_seqlens_q_d,
|
|
|
cu_seqlens_k.data_ptr(),
|
|
|
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
|
|
|
return_softmax ? p.data_ptr() : nullptr,
|
|
@@ -596,7 +630,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
p_dropout,
|
|
|
softmax_scale,
|
|
|
window_size_left,
|
|
|
- window_size_right);
|
|
|
+ window_size_right,
|
|
|
+ seqlenq_ngroups_swapped);
|
|
|
+ if (seqlenq_ngroups_swapped) {
|
|
|
+ // Only apply split-k for decoding
|
|
|
+ set_params_splitkv(params, batch_size, num_heads,
|
|
|
+ head_size, max_seqlen_k, max_seqlen_q,
|
|
|
+ head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
|
|
|
+ }
|
|
|
|
|
|
// number of times random will be generated per thread, to offset philox counter in thc random
|
|
|
// state
|
|
@@ -642,6 +683,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|
|
if (out_.has_value()) { out_.value().copy_(out); }
|
|
|
}
|
|
|
|
|
|
+ if (seqlenq_ngroups_swapped) {
|
|
|
+ long size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
|
|
|
+ long size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
|
|
|
+ out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
|
|
|
+ out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
|
|
|
+ q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
|
|
|
+ softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
|
|
|
+ }
|
|
|
+
|
|
|
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
|
|
|
}
|
|
|
|
|
@@ -1367,23 +1417,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
|
TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
|
|
|
params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
|
|
|
}
|
|
|
- // This needs to match with run_mha_fwd_splitkv_dispatch
|
|
|
- const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
|
|
|
- const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
|
|
|
- // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
|
|
|
- // In any case we don't expect seqlen_q to be larger than 64 for inference.
|
|
|
- const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
|
|
|
- params.num_splits = num_splits;
|
|
|
- if (num_splits < 1) {
|
|
|
- params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
|
|
|
- }
|
|
|
- TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
|
|
|
- if (params.num_splits > 1) {
|
|
|
- at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
|
|
- at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
|
|
|
- params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
|
|
- params.oaccum_ptr = out_accum.data_ptr();
|
|
|
- }
|
|
|
+
|
|
|
+ set_params_splitkv(params, batch_size, num_heads,
|
|
|
+ head_size, seqlen_k, seqlen_q,
|
|
|
+ head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
|
|
|
|
|
|
if (alibi_slopes_.has_value()) {
|
|
|
auto alibi_slopes = alibi_slopes_.value();
|