Browse Source

[LayerNorm] Rename x1 -> residual

Tri Dao 2 years ago
parent
commit
eb33e587e9

+ 3 - 3
csrc/layer_norm/ln.h

@@ -59,7 +59,7 @@ struct ParamsBase {
 
     // Common data pointers.
     void *x0;
-    void *x1;
+    void *residual;
     void *x;
     void *dmask;
     void *mu;
@@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
         , dgamma_part(nullptr)
         , dcolscale_part(nullptr)
         , dx0(nullptr)
-        , dx1(nullptr)
+        , dresidual(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
         , dcolscale(nullptr)
@@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {
 
     // Output: Dgrad.
     void *dx0;
-    void *dx1;
+    void *dresidual;
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;

+ 15 - 15
csrc/layer_norm/ln_api.cpp

@@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
-                                           c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
+                                           c10::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
                                            c10::optional<const at::Tensor> &beta_,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
@@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            bool is_rms_norm=false
 ) {
     auto itype = x0.scalar_type();
-    auto rtype = x1_.has_value()
-        ? x1_.value().scalar_type()
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
         : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
     auto wtype = gamma.scalar_type();
     auto otype = itype;
@@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(gamma.sizes() == beta.sizes());
     }
 
-    if (x1_.has_value()) {
-        auto x1 = x1_.value();
-        TORCH_CHECK(x1.is_cuda())
-        TORCH_CHECK(x1.is_contiguous());
-        TORCH_CHECK(x1.sizes() == sizes);
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda())
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
     }
 
     if (rowscale_.has_value()) {
@@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 
     auto opts = x0.options();
 
-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+    bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
@@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     auto opts = x.options();
 
     auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
     auto dbeta = torch::empty_like(gamma);
     at::Tensor dcolscale;
@@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     launch_params.props = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 
     launcher(launch_params, false);
 
-    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
     if (colscale_.has_value()) {
         result.push_back(dcolscale);
         result.push_back(dcolscale_part);
@@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.doc() = "CUDA DropoutAddLayerNorm";
   m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
-        py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+        py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
         py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
         py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
         py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);

+ 4 - 4
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 
     extern __shared__ char smem_[];
 
-    const bool has_residual = params.dx1 != nullptr;
+    const bool has_residual = params.dresidual != nullptr;
     const bool prenorm = params.dx != nullptr;
 
     const index_t tidx = threadIdx.x;
@@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
-                Rvec dx1;
+                Rvec dresidual;
                 Ivec x0;
                 if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                 #pragma unroll
@@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     } else {
                         dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
                     }
-                    if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
                     if (save_dx0) {
                         compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                         if (Is_dropout) {
@@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                         }
                     }
                 }
-                if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+                if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
                 if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
                 idx_x += Ktraits::VEC_COLS_PER_LDG;
                 idx_x0 += Ktraits::VEC_COLS_PER_LDG;

+ 5 - 5
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;
 
-    const bool has_residual = params.x1 != nullptr;
+    const bool has_residual = params.residual != nullptr;
     const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
 
     extern __shared__ char smem_[];
@@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec x0;
-                Rvec x1;
+                Rvec residual;
                 Rvec x;
                 Mvec dmask;
                 if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-                if (has_residual) { x1.load_from(params.x1, idx_x); }
+                if (has_residual) { residual.load_from(params.residual, idx_x); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
@@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
                         compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
                         x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                         if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-                        x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                     } else {
-                        x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+                        x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
                     }
                     if (save_x) { x.data.elt[jt] = x_ij; }
                     xf[it * NUM_ELTS + jt] = x_ij;

+ 2 - 2
flash_attn/models/gpt.py

@@ -292,7 +292,7 @@ class GPTModel(GPTPreTrainedModel):
                 residual = (dropped + residual) if residual is not None else dropped
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
-                # Set prenorm=False here since we don't need to the residual
+                # Set prenorm=False here since we don't need the residual
                 hidden_states = dropout_add_layer_norm(
                     hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                     self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
@@ -359,7 +359,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         # Previous: Attn / MLP -> Dropout -> Add -> LN
         # Current: Dropout -> Add -> LN -> Attn / MLP
         if 'transformer.ln_0.weight' in state_dict:
-            n_layers = self.config.num_hidden_layers
+            n_layers = len(self.transformer.layers)
             ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
             ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
             state_dict['transformer.ln_f.weight'] = ln_weight

+ 48 - 48
flash_attn/ops/layer_norm.py

@@ -7,20 +7,20 @@ from torch.nn import init
 import dropout_layer_norm
 
 
-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-                                    residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
+                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
         1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
 
 
@@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
                                      dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()
@@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     rowscale = rowscale.view(-1) if rowscale is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
         dropout_p, 1.0, 0, has_residual, is_rms_norm
     )
-    # dx1mat is None if not has_residual
+    # dresidualmat is None if not has_residual
     if colscale is None:
-        return dx0mat, dx1mat, dgamma, dbeta
+        return dx0mat, dresidualmat, dgamma, dbeta
     else:
         dcolscale = rest[0]
-        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
-def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
-                                           dropout_p, epsilon, rowscale_const, out_numrows,
-                                           residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
+                                           out_subset, dropout_p, epsilon, rowscale_const,
+                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     x0_subset = x0_subset.view(-1) if x0_subset is not None else None
     out_subset = out_subset.view(-1) if out_subset is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
 
 
@@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
                                             x0_numrows, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()
@@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     out_subset = out_subset.view(-1) if out_subset is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
         dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
     )
-    # dx1mat is None if not has_residual
+    # dresidualmat is None if not has_residual
     if colscale is None:
-        return dx0mat, dx1mat, dgamma, dbeta
+        return dx0mat, dresidualmat, dgamma, dbeta
     else:
         dcolscale = rest[0]
-        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         rowscale = rowscale.contiguous() if rowscale is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
@@ -118,7 +118,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         if not return_dmask:
@@ -140,29 +140,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
             ctx.is_rms_norm
         )
         dx0 = dx0mat.view(x.shape)
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
-                None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
+                None, None, None, None, None)
 
 
 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                 rowscale_const, out_numrows, residual_in_fp32=False,
                 prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
             rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
@@ -174,7 +174,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         ctx.dropout_p = dropout_p
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])
@@ -197,42 +197,42 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
             ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
-                None, None, None, None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
+                None, None, None, None, None, None, None, None)
 
 
 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
 
 
-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-                           prenorm=False, residual_in_fp32=False,
+def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+                           layerscale=None, prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormFn.apply(
-        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
         False, return_dropout_mask
     )
 
 
-def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                   x0_subset=None, out_subset=None, rowscale_const=1.0,
                                   out_numrows=0, prenorm=False, residual_in_fp32=False,
                                   return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormSubsetFn.apply(
-        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
     )
 
@@ -254,7 +254,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
         init.ones_(self.weight)
         init.zeros_(self.bias)
 
-    def forward(self, x0, x1=None):
-        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
+    def forward(self, x0, residual=None):
+        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                       self.p if self.training else 0.0, self.epsilon,
                                       prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)

+ 12 - 11
flash_attn/ops/rms_norm.py

@@ -12,26 +12,27 @@ def rms_norm(x, weight, epsilon):
                                        False, True)
 
 
-def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-                         prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+                         layerscale=None, prenorm=False, residual_in_fp32=False,
+                         return_dropout_mask=False):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormFn.apply(
-        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
         True, return_dropout_mask
     )
 
 
-def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                   x0_subset=None, out_subset=None, rowscale_const=1.0,
                                   out_numrows=0, prenorm=False, residual_in_fp32=False,
                                   return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormSubsetFn.apply(
-        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
     )
 
@@ -52,7 +53,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
     def reset_parameters(self):
         init.ones_(self.weight)
 
-    def forward(self, x0, x1=None):
-        return dropout_add_rms_norm(x0, x1, self.weight, None,
+    def forward(self, x0, residual=None):
+        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                     self.p if self.training else 0.0, self.epsilon,
                                     prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)