@@ -7,20 +7,20 @@ from torch.nn import init
 import dropout_layer_norm
 
 
-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-                                    residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
+                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
         1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
 
 
@@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
                                      dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()
@@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     rowscale = rowscale.view(-1) if rowscale is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
         dropout_p, 1.0, 0, has_residual, is_rms_norm
     )
-    # dx1mat is None if not has_residual
+    # dresidualmat is None if not has_residual
     if colscale is None:
-        return dx0mat, dx1mat, dgamma, dbeta
+        return dx0mat, dresidualmat, dgamma, dbeta
     else:
         dcolscale = rest[0]
-        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
-def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
-                                           dropout_p, epsilon, rowscale_const, out_numrows,
-                                           residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
+                                           out_subset, dropout_p, epsilon, rowscale_const,
+                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     x0_subset = x0_subset.view(-1) if x0_subset is not None else None
     out_subset = out_subset.view(-1) if out_subset is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
 
 
@@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
                                             x0_numrows, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()
@@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     out_subset = out_subset.view(-1) if out_subset is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
         dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
     )
-    # dx1mat is None if not has_residual
+    # dresidualmat is None if not has_residual
     if colscale is None:
-        return dx0mat, dx1mat, dgamma, dbeta
+        return dx0mat, dresidualmat, dgamma, dbeta
     else:
         dcolscale = rest[0]
-        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         rowscale = rowscale.contiguous() if rowscale is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
@@ -118,7 +118,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         if not return_dmask:
@@ -140,29 +140,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
             ctx.is_rms_norm
         )
         dx0 = dx0mat.view(x.shape)
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
-                None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
+                None, None, None, None, None)
 
 
 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                 rowscale_const, out_numrows, residual_in_fp32=False,
                 prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
             rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
@@ -174,7 +174,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         ctx.dropout_p = dropout_p
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])
@@ -197,42 +197,42 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
             ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
-                None, None, None, None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
+                None, None, None, None, None, None, None, None)
 
 
 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
 
 
-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-                           prenorm=False, residual_in_fp32=False,
+def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+                           layerscale=None, prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormFn.apply(
-        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
         False, return_dropout_mask
     )
 
 
-def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                   x0_subset=None, out_subset=None, rowscale_const=1.0,
                                   out_numrows=0, prenorm=False, residual_in_fp32=False,
                                   return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormSubsetFn.apply(
-        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
     )
 
@@ -254,7 +254,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
         init.ones_(self.weight)
         init.zeros_(self.bias)
 
-    def forward(self, x0, x1=None):
-        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
+    def forward(self, x0, residual=None):
+        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                       self.p if self.training else 0.0, self.epsilon,
                                       prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
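The rename is purely cosmetic: the fused kernel bindings and autograd plumbing are untouched, and call sites only need to pass the same tensor under the keyword `residual` instead of `x1`. For reference, a minimal call-site sketch under the new naming, not part of the patch. It assumes the compiled `dropout_layer_norm` CUDA extension and a CUDA device are available, illustrative shapes and dtypes, and an import path (`flash_attn.ops.layer_norm`) that may differ in your tree.

import torch

# Assumed import path for the module touched by this patch; adjust as needed.
from flash_attn.ops.layer_norm import dropout_add_layer_norm

batch, seqlen, hidden = 2, 128, 1024  # illustrative shapes
x0 = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16)
residual = torch.randn_like(x0)       # the tensor formerly passed as `x1`
weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)

# Fused: out = LayerNorm(dropout(x0) + residual). With prenorm=True the
# pre-norm sum is returned as well, so it can feed the next block's
# residual stream.
out, pre_norm = dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p=0.1, epsilon=1e-5, prenorm=True
)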