فهرست منبع

[LayerNorm] Add option to write result to out and residual_out

Tri Dao 7 ماه پیش
والد
کامیت
bcd918f275
1فایلهای تغییر یافته به همراه35 افزوده شده و 9 حذف شده
  1. 35 9
      flash_attn/ops/triton/layer_norm.py

+ 35 - 9
flash_attn/ops/triton/layer_norm.py

@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ class LayerNormFn(torch.autograd.Function):
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
 
 
@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )
 
 
@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )