|
@@ -637,7 +637,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
|
|
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
|
|
CHECK_DEVICE(out);
|
|
|
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
|
|
- CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
|
|
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
|
|
|
if (seqlenq_ngroups_swapped) {
|
|
|
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|