|
@@ -435,7 +435,7 @@ class FusedMLPFunc(torch.autograd.Function):
|
|
|
grad_input = None
|
|
|
if ctx.heuristic == -1:
|
|
|
if ctx.needs_input_grad[1]:
|
|
|
- if process_group is not None and sequence_parallel:
|
|
|
+ if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
|
|
|
handle_x.wait()
|
|
|
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
|
|
|
total_x.reshape(batch_dim, total_x.shape[-1]),
|
|
@@ -447,7 +447,7 @@ class FusedMLPFunc(torch.autograd.Function):
|
|
|
grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
|
|
|
else:
|
|
|
if ctx.needs_input_grad[1]:
|
|
|
- if process_group is not None and sequence_parallel:
|
|
|
+ if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
|
|
|
handle_x.wait()
|
|
|
grad_weight1 = F.linear(
|
|
|
grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
|