소스 검색

Fix dv = torch::empty_like(k) for mha_bwd_varlen as well

Tri Dao 1 년 전
부모
커밋
d9a5cb291c
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      csrc/flash_attn/flash_api.cpp

+ 1 - 1
csrc/flash_attn/flash_api.cpp

@@ -1069,7 +1069,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
         CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
     } else {
-        dv = torch::empty_like(k);
+        dv = torch::empty_like(v);
     }
 
     at::Tensor dout_padded;