|
@@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size,
|
|
|
params.num_splits = num_splits;
|
|
|
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
|
|
|
if (num_splits < 1) {
|
|
|
- params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
|
|
|
+ // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
|
|
|
+ params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
|
|
|
}
|
|
|
if (params.num_splits > 1) {
|
|
|
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
|
@@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
|
|
// H/t Daniel Haziza
|
|
|
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
|
|
+ const int ngroups = num_heads / num_heads_k;
|
|
|
if (seqlenq_ngroups_swapped) {
|
|
|
- const int ngroups = num_heads / num_heads_k;
|
|
|
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
|
|
seqlen_q = ngroups;
|
|
|
num_heads = num_heads_k;
|
|
@@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
|
|
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
|
|
CHECK_DEVICE(out);
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
|
|
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
|
|
+ CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
|
|
|
+ if (seqlenq_ngroups_swapped) {
|
|
|
+ out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
|
|
+ }
|
|
|
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
|
|
} else {
|
|
|
out = torch::empty_like(q_padded);
|
|
@@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
|
|
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
|
|
// H/t Daniel Haziza
|
|
|
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
|
|
+ const int ngroups = num_heads / num_heads_k;
|
|
|
if (seqlenq_ngroups_swapped) {
|
|
|
- const int ngroups = num_heads / num_heads_k;
|
|
|
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
|
|
max_seqlen_q = ngroups;
|
|
|
num_heads = num_heads_k;
|
|
@@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
|
|
CHECK_DEVICE(out);
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
|
|
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
|
|
+ CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
|
|
|
+ if (seqlenq_ngroups_swapped) {
|
|
|
+ out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
|
|
+ }
|
|
|
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
|
|
} else {
|
|
|
out = torch::empty_like(q_padded);
|