@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ class LayerNormFn(torch.autograd.Function):
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )


@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )


@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
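The patch above threads two new optional keyword arguments, `out` and `residual_out`, through `_layer_norm_fwd`, `LayerNormFn.forward`, `layer_norm_fn`, and `rms_norm_fn`, so callers can supply preallocated buffers for the normalized output and the updated residual stream instead of having every forward call allocate fresh tensors. Below is a minimal caller-side sketch, not part of the patch: the import path and the full `rms_norm_fn(x, weight, bias, ...)` signature are assumptions based on the parameters visible in these hunks, and a CUDA device is assumed since the kernel launches under `torch.cuda.device(x.device.index)`.

```python
import torch

# Hypothetical usage sketch for the new `out` / `residual_out` keywords.
# The import path is an assumption; adjust it to wherever this layer_norm
# module lives in your tree.
from layer_norm import rms_norm_fn

M, N = 4096, 1024
x = torch.randn(M, N, device="cuda", dtype=torch.float16)
weight = torch.ones(N, device="cuda", dtype=torch.float16)

# Preallocate once and reuse across calls (e.g. across decoding steps).
# Both buffers must be contiguous in the last dimension, matching the
# `stride(-1) == 1` asserts in the patched code.
out = torch.empty_like(x)
residual_out = torch.empty(M, N, device="cuda", dtype=torch.float32)

y, residual = rms_norm_fn(
    x,
    weight,
    None,                     # bias (assumed third positional argument)
    prenorm=True,             # also return the updated residual stream
    residual_in_fp32=True,    # residual_out then holds the fp32 residual
    out=out,
    residual_out=residual_out,
)

# The kernel wrote directly into the caller's buffers.
assert y.data_ptr() == out.data_ptr()
assert residual.data_ptr() == residual_out.data_ptr()
```

Since `LayerNormFn.forward` only reshapes the supplied buffers (`out.reshape(-1, out.shape[-1])`) before handing them to the kernel, the writes stay in the caller's storage, which is what the two `data_ptr()` checks verify.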