
[bugfix] handle_x not defined when using checkpoint_lvl = 2 (#502)

When using checkpoint_lvl=2, all_gather_raw(x) is called without async_op=True, so there is no handle to wait on. Just skip the wait.
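For context, a minimal sketch of the failure mode. The `all_gather_raw` below is a simplified stand-in for the flash_attn helper (an assumption about its shape, not the library's exact code); the point is that a distributed all-gather only returns a waitable work handle when `async_op=True`.

```python
import torch
import torch.distributed as dist

def all_gather_raw(x, process_group, async_op=False):
    """Simplified stand-in for flash_attn's all_gather_raw (assumed shape;
    the real helper differs in details, but the handle behavior is the same)."""
    world_size = dist.get_world_size(process_group)
    output = torch.empty(world_size * x.shape[0], *x.shape[1:],
                         dtype=x.dtype, device=x.device)
    # dist.all_gather_into_tensor returns an async work handle only when
    # async_op=True; with async_op=False it returns None.
    handle = dist.all_gather_into_tensor(output, x.contiguous(),
                                         group=process_group, async_op=async_op)
    return output, handle

# checkpoint_lvl 0/1: asynchronous gather overlapped with other work, so the
# backward pass must wait on the handle before using total_x:
#     total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
#     ...                 # other gradient work overlaps the communication
#     handle_x.wait()
#
# checkpoint_lvl == 2: the gather is synchronous, so handle_x is None (or
# never bound at all), and an unconditional handle_x.wait() crashes. That is
# what the added `checkpoint_lvl != 2` guard in the patch below prevents.
```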
GAOXinyu 1 year ago
parent
commit
0cb595ad94
1 changed file with 2 additions and 2 deletions

+ 2 - 2
flash_attn/ops/fused_dense.py

@@ -435,7 +435,7 @@ class FusedMLPFunc(torch.autograd.Function):
             grad_input = None
         if ctx.heuristic == -1:
             if ctx.needs_input_grad[1]:
-                if process_group is not None and sequence_parallel:
+                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                     handle_x.wait()
                 grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                     total_x.reshape(batch_dim, total_x.shape[-1]),
@@ -447,7 +447,7 @@ class FusedMLPFunc(torch.autograd.Function):
                 grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
         else:
             if ctx.needs_input_grad[1]:
-                if process_group is not None and sequence_parallel:
+                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                     handle_x.wait()
                 grad_weight1 = F.linear(
                     grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()