Forráskód Böngészése

[LayerNorm] Fuse LayerScale

Tri Dao 2 éve
szülő
commit
ae137ed17a

+ 8 - 1
csrc/layer_norm/ln.h

@@ -40,6 +40,8 @@ struct ParamsBase {
         , mu(nullptr)
         , rs(nullptr)
         , gamma(nullptr)
+        , rowscale(nullptr)
+        , colscale(nullptr)
         , dropout_keep_p(1.f)
         , dropout_scale(1.f)
         , workspace(nullptr)
@@ -63,6 +65,7 @@ struct ParamsBase {
     void *rs;
     void *gamma;
     void *rowscale;
+    void *colscale;
 
     float inverse_cols;
 
@@ -106,10 +109,12 @@ struct BwdParams : public ParamsBase {
         , dx(nullptr)
         , dbeta_part(nullptr)
         , dgamma_part(nullptr)
+        , dcolscale_part(nullptr)
         , dx0(nullptr)
         , dx1(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
+        , dcolscale(nullptr)
     {
     }
 
@@ -121,6 +126,7 @@ struct BwdParams : public ParamsBase {
     // Workspace for Wgrad pre-reduction.
     void *dbeta_part;
     void *dgamma_part;
+    void *dcolscale_part;
 
     // Output: Dgrad.
     void *dx0;
@@ -128,13 +134,14 @@ struct BwdParams : public ParamsBase {
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;
+    void *dcolscale;
 
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
-using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
+using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
 using FunctionKey = uint64_t;
 using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
 using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

+ 58 - 133
csrc/layer_norm/ln_api.cpp

@@ -84,6 +84,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            const at::Tensor &gamma,   // hidden_size
                                            const at::Tensor &beta,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
+                                           c10::optional<const at::Tensor> &colscale_,      // BxS
                                            const float dropout_p,
                                            const float epsilon,
                                            c10::optional<at::Generator> gen_,
@@ -124,7 +125,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
+        TORCH_CHECK(rowscale.dtype() == itype);
+    }
+
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
     }
 
     TORCH_CHECK(gamma.sizes() == beta.sizes());
@@ -135,7 +144,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 
     auto opts = x0.options();
 
-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
+    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
@@ -153,6 +162,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -212,12 +222,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &dx_,     // BxSxhidden_size
                                            const at::Tensor &x,      // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &x0_,     // BxSxhidden_size
                                            c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
                                            const at::Tensor &mu,     // BxS, FP32!
                                            const at::Tensor &rsigma, // BxS, FP32!
                                            const at::Tensor &gamma,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
+                                           c10::optional<const at::Tensor> &colscale_,      // BxS
                                            const float dropout_p,
                                            const bool has_residual
 ) {
@@ -250,133 +263,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     auto rows = sizes[0];
     auto cols = sizes[1];
 
-    if (dmask_.has_value()) {
-        auto dmask = dmask_.value();
-        TORCH_CHECK(dmask.dtype() == mtype);
-        TORCH_CHECK(dmask.is_cuda());
-        TORCH_CHECK(dmask.is_contiguous());
-        TORCH_CHECK(dmask.sizes() == sizes);
-    }
-
-    if (rowscale_.has_value()) {
-        auto rowscale = rowscale_.value();
-        TORCH_CHECK(rowscale.is_cuda())
-        TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
-    }
-
-    auto hidden_size = gamma.numel();
-    TORCH_CHECK(hidden_size == cols);
-    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
-
-    TORCH_CHECK(mu.numel() == rows);
-    TORCH_CHECK(mu.sizes() == rsigma.sizes());
-
-    TORCH_CHECK(gamma.numel() == cols);
-
-    auto opts = x.options();
-
-    auto dx0 = torch::empty_like(x, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
-    auto dgamma = torch::empty_like(gamma);
-    auto dbeta = torch::empty_like(gamma);
-
-    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
-    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-    launch_params.props = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dropout_p < 1.f);
-    launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
-    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
-
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-    launcher(launch_params, true, /*prenorm=*/false);
-
-    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-    at::Tensor workspace, barrier;
-
-    layer_norm::BwdParams &params = launch_params.params;
-    params.rows = rows;
-    params.cols = cols;
-    params.x = x.data_ptr();
-    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
-    params.mu = mu.data_ptr();
-    params.rs = rsigma.data_ptr();
-    params.gamma = gamma.data_ptr();
-    params.dz = dz.data_ptr();
-    params.dx0 = dx0.data_ptr();
-    params.dbeta = dbeta.data_ptr();
-    params.dgamma = dgamma.data_ptr();
-    params.dbeta_part = dbeta_part.data_ptr();
-    params.dgamma_part = dgamma_part.data_ptr();
-    params.dropout_scale = 1.f / (1.f - dropout_p);
-    params.inverse_cols = 1.f / float(params.cols);
-
-    if( launch_params.barrier_size > 0 ) {
-        // TODO Any way to avoid this?
-        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
-        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
-        params.workspace = workspace.data_ptr();
-        params.barrier = barrier.data_ptr<int>();
+    if (dx_.has_value()) {
+        auto dx = dx_.value();
+        TORCH_CHECK(dx.dtype() == rtype);
+        TORCH_CHECK(dx.is_cuda())
+        TORCH_CHECK(dx.is_contiguous());
+        TORCH_CHECK(dx.sizes() == sizes);
     }
 
-    launcher(launch_params, false, /*prenorm=*/false);
-
-    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     // BxSxhidden_size
-                                                   const at::Tensor &dx,     // BxSxhidden_size
-                                                   const at::Tensor &x,      // BxSxhidden_size
-                                                   c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
-                                                   const at::Tensor &mu,     // BxS, FP32!
-                                                   const at::Tensor &rsigma, // BxS, FP32!
-                                                   const at::Tensor &gamma,   // hidden_size
-                                                   c10::optional<const at::Tensor> &rowscale_,      // BxS
-                                                   const float dropout_p,
-                                                   const bool has_residual
-) {
-
-    auto itype = dz.scalar_type();
-    auto rtype = x.scalar_type();
-    auto wtype = gamma.scalar_type();
-    auto otype = itype;
-    auto ctype = torch::kFloat32;
-    auto mtype = torch::kUInt8;
-
-    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
-
-    TORCH_CHECK(dz.dtype() == otype);
-    TORCH_CHECK(dx.dtype() == rtype);
-    TORCH_CHECK(mu.dtype() == ctype);
-    TORCH_CHECK(rsigma.dtype() == ctype);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(dz.is_cuda());
-    TORCH_CHECK(dx.is_cuda());
-    TORCH_CHECK(mu.is_cuda());
-    TORCH_CHECK(rsigma.is_cuda());
-    TORCH_CHECK(gamma.is_cuda());
-
-    TORCH_CHECK(x.is_contiguous());
-    TORCH_CHECK(dz.is_contiguous());
-    TORCH_CHECK(dx.is_contiguous());
-
-    auto sizes = x.sizes();
-    TORCH_CHECK(sizes.size() == 2);
-    TORCH_CHECK(dz.sizes() == sizes);
-    TORCH_CHECK(dx.sizes() == sizes);
-    auto rows = sizes[0];
-    auto cols = sizes[1];
-
     if (dmask_.has_value()) {
         auto dmask = dmask_.value();
         TORCH_CHECK(dmask.dtype() == mtype);
@@ -390,7 +284,22 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
+        TORCH_CHECK(rowscale.dtype() == itype);
+    }
+
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
+
+        TORCH_CHECK(x0_.has_value());
+        auto x0 = x0_.value();
+        TORCH_CHECK(x0.is_cuda())
+        TORCH_CHECK(x0.is_contiguous());
+        TORCH_CHECK(x0.sizes() == sizes);
+        TORCH_CHECK(x0.dtype() == itype);
     }
 
     auto hidden_size = gamma.numel();
@@ -409,6 +318,10 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
     if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
     auto dbeta = torch::empty_like(gamma);
+    at::Tensor dcolscale;
+    if (colscale_.has_value()) {
+        dcolscale = torch::empty_like(colscale_.value());
+    }
 
     layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
@@ -417,32 +330,40 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
     auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
-    launcher(launch_params, true, /*prenorm=*/true);
+    launcher(launch_params, true);
 
     auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
     auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor dcolscale_part;
+    if (colscale_.has_value()) {
+        dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    }
     at::Tensor workspace, barrier;
 
     layer_norm::BwdParams &params = launch_params.params;
     params.rows = rows;
     params.cols = cols;
     params.x = x.data_ptr();
+    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
     params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
     params.mu = mu.data_ptr();
     params.rs = rsigma.data_ptr();
     params.gamma = gamma.data_ptr();
     params.dz = dz.data_ptr();
-    params.dx = dx.data_ptr();
+    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
     params.dx0 = dx0.data_ptr();
     params.dbeta = dbeta.data_ptr();
     params.dgamma = dgamma.data_ptr();
+    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
     params.dbeta_part = dbeta_part.data_ptr();
     params.dgamma_part = dgamma_part.data_ptr();
+    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
 
@@ -454,9 +375,14 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
         params.barrier = barrier.data_ptr<int>();
     }
 
-    launcher(launch_params, false, /*prenorm=*/true);
+    launcher(launch_params, false);
 
-    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    if (colscale_.has_value()) {
+        result.push_back(dcolscale);
+        result.push_back(dcolscale_part);
+    }
+    return result;
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -464,5 +390,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.doc() = "CUDA DropoutAddLayerNorm";
   m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
   m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
-  m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
 }

+ 118 - 57
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -7,7 +7,7 @@
 
 namespace layer_norm {
 
-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
 void ln_bwd_kernel(layer_norm::BwdParams params) {
 
@@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 
     Cvec dzy_sum[LDGS];
     Cvec dz_sum[LDGS];
+    Cvec dcolscale_sum[LDGS];
 
     memset(dzy_sum, 0, sizeof(dzy_sum));
     memset(dz_sum, 0, sizeof(dz_sum));
+    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
 
     compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
     char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
@@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
 
     Wvec gamma[LDGS];
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += Ktraits::VEC_COLS_PER_LDG;
         }
     }
@@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
                 Rvec dx1;
+                Ivec x0;
+                if (Has_colscale) { x0.load_from(params.x0, idx); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
@@ -140,9 +146,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
                     compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                     if (Is_dropout) {
-                        dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
+                        dx0_tmp_res *= params.dropout_scale;
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
+                        } else {
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
+                        }
                     } else {
-                        dx0.data.elt[jt] = dx0_tmp_res;
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
+                            dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
+                        } else {
+                            dx0.data.elt[jt] = dx0_tmp_res;
+                        }
                     }
                 }
                 if (Has_residual) { dx1.store_to(params.dx1, idx); }
@@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 dz_sum[it].store_to(params.dbeta_part, idx);
                 dzy_sum[it].store_to(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
                 idx += Ktraits::VEC_COLS_PER_LDG;
             }
         }
@@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             }
         }
 
+        compute_t cta_dcolscale_sum[NUM_RES];
+        if (Has_colscale) {
+            __syncthreads();
+            idx = warp_m * Ktraits::VEC_COLS + tid_r;
+            #pragma unroll
+            for( int it = 0; it < LDGS; it++ ) {
+                dcolscale_sum[it].store_to(smem_wgrad, idx);
+                idx += THREADS_PER_ROW;
+            }
+            __syncthreads();
+            memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
+            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+                for( int jt = 0; jt < NUM_RES; jt++ ) {
+                    cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+                }
+            }
+        }
+
         const index_t num_valid_writes
             = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
         compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
         compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
+        compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
         for( int jt = 0; jt < NUM_RES; jt++ ) {
             if (Is_even_cols || (jt < num_valid_writes)) {
                 *dgamma_part = cta_dzy_sum[jt];
                 dgamma_part += Ktraits::THREADS_PER_CTA;
                 *dbeta_part = cta_dz_sum[jt];
                 dbeta_part += Ktraits::THREADS_PER_CTA;
+                if (Has_colscale) {
+                    *dcolscale_part = cta_dcolscale_sum[jt];
+                    dcolscale_part += Ktraits::THREADS_PER_CTA;
+                }
             }
         }
 
     }
 }
 
-template<typename Kernel_traits, bool Is_even_cols>
+template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
 void ln_bwd_finalize_kernel(BwdParams params)
 {
@@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params)
     constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
     for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
         // Each thread sums over NUM_ELT columns.
-        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
+        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
         memset(&dgamma_local, 0, sizeof(dgamma_local));
         memset(&dbeta_local, 0, sizeof(dbeta_local));
+        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
         if (Is_even_cols || col < params.cols) {
             for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
-                // index_t idx = row * Kernel_traits::COLS + col;
                 index_t idx = row * params.cols + col;
 
-                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
+                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
                 dbeta_part.load_from(params.dbeta_part, idx);
                 dgamma_part.load_from(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
                 #pragma unroll
                 for( int it = 0; it < NUM_ELT; it++ ) {
                     dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
                     dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
                 }
             }
         }
         void * smem_gamma = smem_;
         void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
 
         const int write_row = warp;
         const int write_col = lane ^ write_row;
@@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params)
 
         dgamma_local.store_to(smem_gamma, write_idx);
         dbeta_local.store_to(smem_beta, write_idx);
+        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
 
         __syncthreads();
 
         // It would be probably safe to reuse the first row of smem_beta and smem_gamma
-        void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
-        void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
 
 
         // More than one iter iff ROWS_PER_CTA < 32.
@@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params)
 
             memset(&dbeta_local, 0, sizeof(dbeta_local));
             memset(&dgamma_local, 0, sizeof(dgamma_local));
+            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
 
             // Load beta and gamma transposed 
             if(read_row < Kernel_traits::ROWS_PER_CTA){
                 dbeta_local.load_from(smem_beta, read_idx);
                 dgamma_local.load_from(smem_gamma, read_idx);
+                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
             }
 
             // Call reducer on the loaded value(s) and convert.
@@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params)
 
                 dgamma_local.data.elt[it] = g_i;
                 dbeta_local.data.elt[it] = b_i;
+                if (Has_colscale) {
+                    compute_t cs_i = dcolscale_local.data.elt[it];
+                    cs_i = reducer.allreduce(cs_i, sum);
+                    dcolscale_local.data.elt[it] = cs_i;
+                }
             }
 
             // Leader stores the result at the current column.
             if(lane == 0){
                 dgamma_local.store_to(smem_gamma_out, w);
                 dbeta_local.store_to(smem_beta_out, w);
+                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
             }
 
         }
@@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params)
 
                 using src_t = typename TypeToVec2<compute_t>::Type;
                 using dst_t = typename TypeToVec2<weight_t>::Type;
-                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
-                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
+                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
 
                 dgamma_vec2.load_from(smem_gamma_out, lane);
                 dbeta_vec2.load_from(smem_beta_out, lane);
+                if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
                 #pragma unroll
                 for( int it = 0; it < NUM_ELT; it++ ) {
                     dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
                     dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
+                    if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
                 }
                 dgamma_out2.store_to(params.dgamma, col_out);
                 dbeta_out2.store_to(params.dbeta, col_out);
-
+                if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
             }
         }
     }
@@ -364,7 +420,7 @@ template<
     int BYTES_PER_LDG_MAIN,
     int BYTES_PER_LDG_FINAL
 >
-void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
 
     using Kernel_traits = Kernel_traits<weight_t,
                                         input_t,
@@ -378,59 +434,64 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                         WARPS_N,
                                         BYTES_PER_LDG_MAIN
                                         >;
+    bool prenorm = launch_params.params.dx != nullptr;
     bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
     bool has_residual = launch_params.params.dx1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(prenorm, PrenormConst, [&] {
         BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
             BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
-                    if( configure_params ) {
-                        int ctas_per_sm;
-                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
-                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
-                        launch_params.barrier_size = 0;
-                        launch_params.workspace_bytes = 0;
-                        if(Kernel_traits::CTAS_PER_ROW > 1) {
-                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
-                                                          * Kernel_traits::WARPS_M
-                                                          * Kernel_traits::CTAS_PER_ROW
-                                                          * sizeof(typename Kernel_traits::reduce_t)
-                                                          * 2;
+                BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+                    BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                        auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+                        if( configure_params ) {
+                            int ctas_per_sm;
+                            CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                                &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
+                            launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                            launch_params.barrier_size = 0;
+                            launch_params.workspace_bytes = 0;
+                            if(Kernel_traits::CTAS_PER_ROW > 1) {
+                                launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                                launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                              * Kernel_traits::WARPS_M
+                                                              * Kernel_traits::CTAS_PER_ROW
+                                                              * sizeof(typename Kernel_traits::reduce_t)
+                                                              * 2;
+                            }
+                            return;
                         }
-                        return;
-                    }
-
-                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
-                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
-                    }
-                    auto stream = launch_params.stream;
-                    auto ctas_per_col = launch_params.params.ctas_per_col;
 
-                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
-                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
-                    } else {
-                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
-                        dim3 block(Kernel_traits::THREADS_PER_CTA);
-                        void *params_ = (void *)&launch_params.params;
-                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
-                    }
+                        if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
+                            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+                        }
+                        auto stream = launch_params.stream;
+                        auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                        if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                            kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
+                        } else {
+                            dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                            dim3 block(Kernel_traits::THREADS_PER_CTA);
+                            void *params_ = (void *)&launch_params.params;
+                            cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
+                        }
 
-                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
-                                                                              weight_t,
-                                                                              input_t,
-                                                                              residual_t,
-                                                                              output_t,
-                                                                              compute_t,
-                                                                              index_t,
-                                                                              32 * 32,  // THREADS_PER_CTA
-                                                                              BYTES_PER_LDG_FINAL>;
-
-                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
-                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                        using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
+                                                                                  weight_t,
+                                                                                  input_t,
+                                                                                  residual_t,
+                                                                                  output_t,
+                                                                                  compute_t,
+                                                                                  index_t,
+                                                                                  HasColscaleConst,
+                                                                                  32 * 32,  // THREADS_PER_CTA
+                                                                                  BYTES_PER_LDG_FINAL>;
+
+                        auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
+                        kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                    });
                 });
             });
         });

+ 45 - 44
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -16,7 +16,7 @@
 
 namespace layer_norm {
 
-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
 void ln_fwd_kernel(FwdParams params) {
 
@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;
 
-    constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
+    const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);
 
     extern __shared__ char smem_[];
 
@@ -80,12 +80,14 @@ void ln_fwd_kernel(FwdParams params) {
 
     Wvec gamma[LDGS];
     Wvec beta[LDGS];
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
             beta[it].load_from(params.beta, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += VEC_COLS_PER_LDG;
         }
     }
@@ -109,13 +111,9 @@ void ln_fwd_kernel(FwdParams params) {
                     // the more efficient curand_uniform4.
                     mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                     compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
-                    compute_t x_ij;
-                    if (Has_residual) {
-                        compute_t x1_ij = compute_t(x1.data.elt[jt]);
-                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
-                    } else  {
-                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
-                    }
+                    x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
+                    if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
+                    compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
                     if (save_x) { x.data.elt[jt] = x_ij; }
                     xf[it * NUM_ELTS + jt] = x_ij;
                     if (Is_dropout) { dmask.data.elt[jt] = keep; }
@@ -130,8 +128,8 @@ void ln_fwd_kernel(FwdParams params) {
         const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
         const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
         const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
-        // Need to convert to int, otherwise the subtraction will wrap around.
         auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+            // Need to convert to int, otherwise the subtraction will wrap around.
             const index_t valid_partial_vecs_in_warp =
                 std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                         int(THREADS_PER_WARP));
@@ -206,45 +204,48 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
                                         BYTES_PER_LDG
                                         >;
     bool has_residual = launch_params.params.x1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
         BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
-                if( configure_params ) {
-                    int ctas_per_sm;
-                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
-                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
-                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
-                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
-                    launch_params.barrier_size = 0;
-                    launch_params.workspace_bytes = 0;
-                    if(Kernel_traits::CTAS_PER_ROW > 1) {
-                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
-                                                      * Kernel_traits::WARPS_M
-                                                      * Kernel_traits::CTAS_PER_ROW
-                                                      * sizeof(typename Kernel_traits::Stats::stats_t)
-                                                      * 2;
+            BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+                    if( configure_params ) {
+                        int ctas_per_sm;
+                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
+                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                        const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
+                        launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
+                        launch_params.barrier_size = 0;
+                        launch_params.workspace_bytes = 0;
+                        if(Kernel_traits::CTAS_PER_ROW > 1) {
+                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                          * Kernel_traits::WARPS_M
+                                                          * Kernel_traits::CTAS_PER_ROW
+                                                          * sizeof(typename Kernel_traits::Stats::stats_t)
+                                                          * 2;
+                        }
+                        return;
                     }
-                    return;
-                }
 
-                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
-                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
-                }
-                auto stream = launch_params.stream;
-                auto ctas_per_col = launch_params.params.ctas_per_col;
-
-                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
-                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
-                } else {
-                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
-                    dim3 block(Kernel_traits::THREADS_PER_CTA);
-                    void *params_ = (void *)&launch_params.params;
-                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
-                }
+                    if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
+                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+                    }
+                    auto stream = launch_params.stream;
+                    auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
+                    } else {
+                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                        dim3 block(Kernel_traits::THREADS_PER_CTA);
+                        void *params_ = (void *)&launch_params.params;
+                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
+                    }
+                });
             });
         });
     });

+ 3 - 1
csrc/layer_norm/ln_kernel_traits.h

@@ -38,6 +38,7 @@ template<
     typename output_t_,
     typename compute_t_,
     typename index_t_,
+    bool Has_colscale,
     uint32_t THREADS_PER_CTA_,
     uint32_t BYTES_PER_LDG_,
     typename Base = Kernel_traits_base<HIDDEN_SIZE_,
@@ -69,7 +70,8 @@ struct Kernel_traits_finalize : public Base {
     // Shared memory size to coalsece the CTA result.
     enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
     // Shared memory requirement per CTA. 
-    enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
+    static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
+    enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
 
     // The type of the reducer.
     using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;

+ 2 - 2
csrc/layer_norm/ln_utils.cuh

@@ -45,7 +45,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
 #define REGISTER_BWD_LAUNCHER(                                                                                                                  \
     HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                      \
     void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,                         \
-                                                                                const bool configure_params, const bool prenorm) {              \
+                                                                                const bool configure_params) {                                  \
         launch_<WTYPE,                                                                                                                          \
                 ITYPE,                                                                                                                          \
                 RTYPE,                                                                                                                          \
@@ -57,7 +57,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
                 WARPS_M,                                                                                                                        \
                 WARPS_N,                                                                                                                        \
                 BYTES_PER_LDG,                                                                                                                  \
-                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm);                                                              \
+                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                                       \
     }                                                                                                                                           \
     static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(    \
         ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)

+ 47 - 80
flash_attn/ops/layer_norm.py

@@ -1,11 +1,13 @@
+# Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
+
 import torch
 from torch.nn import init
 
 import dropout_layer_norm
 
 
-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
+def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                                     residual_in_fp32):
     """ Assume that arguments are contiguous
     """
@@ -14,133 +16,98 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, ep
     x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
+        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
 
 
-def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
-                                     has_residual):
+def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
+                                     dropout_p, has_residual):
     """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + x1 was not returned in the fwd).
+    x0 must not be None if we have colscale.
     """
-    # dmask is None if dropout_p == 0.0
     hidden_size = gamma.numel()
     xmat = x.view((-1, hidden_size))
     dzmat = dz.view(xmat.shape)
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
-    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
-        dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+    colscale = colscale.view(-1) if colscale is not None else None
+    if colscale is not None:
+        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
+    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
+        has_residual
     )
     # dx1mat is None if not has_residual
-    return dx0mat, dx1mat, dgamma, dbeta
-
-
-def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
-                                             dropout_p, has_residual):
-    """ Assume that arguments are contiguous
-    """
-    hidden_size = gamma.numel()
-    xmat = x.view((-1, hidden_size))
-    dzmat = dz.view(xmat.shape)
-    dxmat = dx.view(xmat.shape)
-    rowscale = rowscale.view(-1) if rowscale is not None else None
-    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
-        dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
-    )
-    return dx0mat, dx1mat, dgamma, dbeta
+    if colscale is None:
+        return dx0mat, dx1mat, dgamma, dbeta
+    else:
+        dcolscale = rest[0]
+        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
 
 
-class DropoutAddLayerNormFN(torch.autograd.Function):
+class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
-                return_dmask=False):
+    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
+                prenorm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous()
         rowscale = rowscale.contiguous() if rowscale is not None else None
+        colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
+            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
         )
-        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
+        # Only need to save x0 if we need to compute gradient wrt colscale
+        x0_saved = x0 if colscale is not None else None
+        ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
+        ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = x1 is not None
         if not return_dmask:
-            return zmat.view(x0.shape)
+            return (zmat.view(x0.shape) if not prenorm
+                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
         else:
             dmask = (dmask.view(x0.shape) if dropout_p > 0.
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
             ctx.mark_non_differentiable(dmask)
-            return zmat.view(x0.shape), dmask
+            return ((zmat.view(x0.shape), dmask) if not prenorm
+                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
 
     @staticmethod
     def backward(ctx, dz, *args):
         # assert dz.is_contiguous()
         dz = dz.contiguous()  # this happens!
-        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
-        dropout_p = ctx.dropout_p
-        has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
-            dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
-        )
-        dx0 = dx0mat.view(x.shape)
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
-        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
-
-
-class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
-                return_dmask=False):
-        x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
-        gamma = gamma.contiguous()
-        beta = beta.contiguous()
-        rowscale = rowscale.contiguous() if rowscale is not None else None
-        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
-        )
-        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
-        ctx.dropout_p = dropout_p
-        ctx.has_residual = x1 is not None
-        if not return_dmask:
-            return zmat.view(x0.shape), xmat.view(x0.shape)
-        else:
-            dmask = (dmask.view(x0.shape) if dropout_p > 0.
-                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
-            ctx.mark_non_differentiable(dmask)
-            return zmat.view(x0.shape), xmat.view(x0.shape), dmask
-
-    @staticmethod
-    def backward(ctx, dz, dx, *args):
-        # assert dz.is_contiguous()
-        dz = dz.contiguous()  # this happens!
-        dx = dx.contiguous()  # this happens!
-        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
+        # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
-            dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
         )
         dx0 = dx0mat.view(x.shape)
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
-        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
+        dcolscale = rest[0] if colscale is not None else None
+        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
 
 
-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
+def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                            prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
     """residual_in_fp32 only has an effect if x1 is None.
     Otherwise residual dtype is x1.dtype.
     """
-    args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
-            return_dropout_mask)
-    if not prenorm:
-        return DropoutAddLayerNormFN.apply(*args)
-    else:
-        return DropoutAddLayerNormPrenormFN.apply(*args)
+    return DropoutAddLayerNormFn.apply(
+        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        return_dropout_mask
+    )
 
 
 class DropoutAddLayerNorm(torch.nn.Module):

+ 31 - 12
tests/ops/test_dropout_layer_norm.py

@@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
 
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 
+@pytest.mark.parametrize('has_colscale', [True, False])
 @pytest.mark.parametrize('has_rowscale', [True, False])
 # @pytest.mark.parametrize('has_rowscale', [True])
 @pytest.mark.parametrize('has_residual', [True, False])
@@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
 @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                     dropout_p, has_residual, has_rowscale):
+                                     dropout_p, has_residual, has_rowscale, has_colscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
-    # Backward numerical error is high, and this case isn't used
-    if has_rowscale and not has_residual:
-        pytest.skip()
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 1e-4)
@@ -43,6 +41,12 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
                         requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_colscale:
+        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        colscale_pt = colscale.detach().clone().requires_grad_()
+        colscale_ref = colscale.detach().clone().float().requires_grad_()
+    else:
+        colscale = None
     if has_residual:
         x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
         x1 = x1_pt.detach().clone().requires_grad_()
@@ -59,6 +63,9 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
         rowscale = None
         x0_scaled_pt = x0_pt
         x0_scaled_ref = x0_ref
+    if has_colscale:
+        x0_scaled_pt = x0_scaled_pt * colscale_pt
+        x0_scaled_ref = x0_scaled_ref * colscale_ref
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
     torch.nn.init.normal_(model_pt.bias)
@@ -71,7 +78,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
         model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                        model.epsilon, rowscale=rowscale,
+                                        model.epsilon, rowscale=rowscale, layerscale=colscale,
                                         residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@@ -94,6 +101,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
+    if has_colscale:
+        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
 
 
 @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@@ -139,6 +148,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
 
 
+@pytest.mark.parametrize('has_colscale', [True, False])
 @pytest.mark.parametrize('has_rowscale', [True, False])
 @pytest.mark.parametrize('has_residual', [True, False])
 @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@@ -147,20 +157,17 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
 # @pytest.mark.parametrize('has_rowscale', [False])
-# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('has_residual', [False])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 # @pytest.mark.parametrize('weight_dtype', [torch.float32])
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
 @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                             dropout_p, has_residual, has_rowscale):
+                                             dropout_p, has_residual, has_rowscale, has_colscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
-    # Backward numerical error is high, and this case isn't used
-    if has_rowscale and not has_residual:
-        pytest.skip()
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 2e-4)
@@ -172,6 +179,12 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
                         requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_colscale:
+        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        colscale_pt = colscale.detach().clone().requires_grad_()
+        colscale_ref = colscale.detach().clone().float().requires_grad_()
+    else:
+        colscale = None
     if has_residual:
         x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
         x1 = x1_pt.detach().clone().requires_grad_()
@@ -188,6 +201,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
         rowscale = None
         x0_scaled_pt = x0_pt
         x0_scaled_ref = x0_ref
+    if has_colscale:
+        x0_scaled_pt = x0_scaled_pt * colscale_pt
+        x0_scaled_ref = x0_scaled_ref * colscale_ref
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
     model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
@@ -199,7 +215,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
         model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                                  model.epsilon, rowscale=rowscale, prenorm=True,
+                                                  model.epsilon, rowscale=rowscale,
+                                                  layerscale=colscale, prenorm=True,
                                                   residual_in_fp32=residual_in_fp32,
                                                   return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@@ -225,6 +242,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if has_colscale:
+        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
 
 
 @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])