Browse Source

Fix typo with lse_max == -INFINITY

Tri Dao 1 năm trước cách đây
mục cha
commit
31920dda5f
1 tập tin đã thay đổi với 1 bổ sung1 xóa
  1. 1 1
      csrc/flash_attn/src/flash_fwd_kernel.h

+ 1 - 1
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -1118,7 +1118,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
     MaxOp<float> max_op;
     lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
-    lse_max == lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
+    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
     float lse_sum = expf(lse_accum(0) - lse_max);
     #pragma unroll
     for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }