|
@@ -1118,7 +1118,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) {
|
|
|
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
|
|
|
MaxOp<float> max_op;
|
|
|
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
|
|
|
- lse_max == lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
|
|
|
+ lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
|
|
|
float lse_sum = expf(lse_accum(0) - lse_max);
|
|
|
#pragma unroll
|
|
|
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
|