Ver Fonte

Fixing argument checking when using `seqlenq_ngroups_swapped`. (#976)

When user send `out` as a parameter of the function
`seqlenq_ngroups_swapped` with parameters that trigger,
the CHECK_SHAPE is incorrect (since q shape is modified.)
Nicolas Patry há 8 meses atrás
pai
commit
5bf201966a
1 ficheiros alterados com 0 adições e 1 exclusões
  1. 0 1
      csrc/flash_attn/flash_api.cpp

+ 0 - 1
csrc/flash_attn/flash_api.cpp

@@ -637,7 +637,6 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
         CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
         if (seqlenq_ngroups_swapped) {
             out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});