|
@@ -830,7 +830,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
|
|
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
|
|
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
|
|
|
} else {
|
|
|
- dv = torch::empty_like(k);
|
|
|
+ dv = torch::empty_like(v);
|
|
|
}
|
|
|
|
|
|
at::Tensor dout_padded;
|