
[LayerNorm] Implement LN with parallel residual, support dim 8k

Tri Dao, 2 years ago
Commit 393882bc08
46 changed files with 2348 additions and 114 deletions
  1. csrc/layer_norm/README.md (+7, -4)
  2. csrc/layer_norm/ln.h (+43, -3)
  3. csrc/layer_norm/ln_api.cpp (+408, -32)
  4. csrc/layer_norm/ln_bwd_7168.cu (+15, -0)
  5. csrc/layer_norm/ln_bwd_8192.cu (+15, -0)
  6. csrc/layer_norm/ln_fwd_7168.cu (+15, -0)
  7. csrc/layer_norm/ln_fwd_8192.cu (+15, -0)
  8. csrc/layer_norm/ln_parallel_bwd_1024.cu (+15, -0)
  9. csrc/layer_norm/ln_parallel_bwd_1280.cu (+15, -0)
  10. csrc/layer_norm/ln_parallel_bwd_1536.cu (+15, -0)
  11. csrc/layer_norm/ln_parallel_bwd_2048.cu (+15, -0)
  12. csrc/layer_norm/ln_parallel_bwd_256.cu (+15, -0)
  13. csrc/layer_norm/ln_parallel_bwd_2560.cu (+15, -0)
  14. csrc/layer_norm/ln_parallel_bwd_3072.cu (+15, -0)
  15. csrc/layer_norm/ln_parallel_bwd_4096.cu (+17, -0)
  16. csrc/layer_norm/ln_parallel_bwd_512.cu (+15, -0)
  17. csrc/layer_norm/ln_parallel_bwd_5120.cu (+17, -0)
  18. csrc/layer_norm/ln_parallel_bwd_6144.cu (+15, -0)
  19. csrc/layer_norm/ln_parallel_bwd_7168.cu (+15, -0)
  20. csrc/layer_norm/ln_parallel_bwd_768.cu (+15, -0)
  21. csrc/layer_norm/ln_parallel_bwd_8192.cu (+15, -0)
  22. csrc/layer_norm/ln_parallel_fwd_1024.cu (+15, -0)
  23. csrc/layer_norm/ln_parallel_fwd_1280.cu (+15, -0)
  24. csrc/layer_norm/ln_parallel_fwd_1536.cu (+15, -0)
  25. csrc/layer_norm/ln_parallel_fwd_2048.cu (+15, -0)
  26. csrc/layer_norm/ln_parallel_fwd_256.cu (+15, -0)
  27. csrc/layer_norm/ln_parallel_fwd_2560.cu (+15, -0)
  28. csrc/layer_norm/ln_parallel_fwd_3072.cu (+15, -0)
  29. csrc/layer_norm/ln_parallel_fwd_4096.cu (+15, -0)
  30. csrc/layer_norm/ln_parallel_fwd_512.cu (+15, -0)
  31. csrc/layer_norm/ln_parallel_fwd_5120.cu (+15, -0)
  32. csrc/layer_norm/ln_parallel_fwd_6144.cu (+15, -0)
  33. csrc/layer_norm/ln_parallel_fwd_7168.cu (+15, -0)
  34. csrc/layer_norm/ln_parallel_fwd_768.cu (+15, -0)
  35. csrc/layer_norm/ln_parallel_fwd_8192.cu (+15, -0)
  36. csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh (+540, -0)
  37. csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh (+281, -0)
  38. csrc/layer_norm/ln_utils.cuh (+33, -0)
  39. csrc/layer_norm/setup.py (+32, -0)
  40. flash_attn/models/gpt.py (+21, -8)
  41. flash_attn/modules/block.py (+30, -15)
  42. flash_attn/ops/layer_norm.py (+109, -2)
  43. flash_attn/ops/rms_norm.py (+14, -0)
  44. tests/models/test_gpt_neox.py (+1, -1)
  45. tests/models/test_gptj.py (+1, -1)
  46. tests/ops/test_dropout_layer_norm.py (+344, -48)

+ 7 - 4
csrc/layer_norm/README.md

@@ -1,10 +1,13 @@
 This CUDA extension implements fused dropout + residual + LayerNorm, building on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
-We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
-We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
-We also implement RMSNorm as an option.
+Major changes:
+- Add dropout and residual.
+- Make it work for both pre-norm and post-norm architecture.
+- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
+- Implement RMSNorm as an option.
+- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
 
-If you want to use it for dimensions larger than 6k, please file an issue.
+If you want to use it for dimensions larger than 8k, please file an issue.
 
 This extension has only been tested on A100s.
 

+ 43 - 3
csrc/layer_norm/ln.h

@@ -14,7 +14,7 @@ namespace layer_norm {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Params> 
+template<typename Params>
 struct LaunchParams{
 
     size_t elts_per_thread;
@@ -40,6 +40,7 @@ struct ParamsBase {
         , mu(nullptr)
         , rs(nullptr)
         , gamma(nullptr)
+        , gamma1(nullptr)
         , rowscale(nullptr)
         , colscale(nullptr)
         , dropout_keep_p(1.f)
@@ -59,12 +60,15 @@ struct ParamsBase {
 
     // Common data pointers.
     void *x0;
+    void *x1;
     void *residual;
     void *x;
     void *dmask;
+    void *dmask1;
     void *mu;
     void *rs;
     void *gamma;
+    void *gamma1;
     void *rowscale;
     void *colscale;
     void *x0_subset;
@@ -92,14 +96,18 @@ struct FwdParams : public ParamsBase {
     FwdParams()
         : ParamsBase()
         , z(nullptr)
+        , z1(nullptr)
         , beta(nullptr)
+        , beta1(nullptr)
         , epsilon(0.f)
     {
     }
 
     // Output of LN FWD.
     void *z;
+    void *z1;
     void *beta;
+    void *beta1;
     float epsilon;
 
     // Random state.
@@ -112,34 +120,46 @@ struct BwdParams : public ParamsBase {
     BwdParams()
         : ParamsBase()
         , dz(nullptr)
+        , dz1(nullptr)
         , dx(nullptr)
         , dbeta_part(nullptr)
         , dgamma_part(nullptr)
+        , dbeta1_part(nullptr)
+        , dgamma1_part(nullptr)
         , dcolscale_part(nullptr)
         , dx0(nullptr)
+        , dx1(nullptr)
         , dresidual(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
+        , dbeta1(nullptr)
+        , dgamma1(nullptr)
         , dcolscale(nullptr)
     {
     }
 
     // Input: gradient wrt. LN FWD output.
     void *dz;
+    void *dz1;
     // Input: gradient wrt residual.
     void *dx;
 
     // Workspace for Wgrad pre-reduction.
     void *dbeta_part;
     void *dgamma_part;
+    void *dbeta1_part;
+    void *dgamma1_part;
     void *dcolscale_part;
 
     // Output: Dgrad.
     void *dx0;
+    void *dx1;
     void *dresidual;
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;
+    void *dbeta1;
+    void *dgamma1;
     void *dcolscale;
 
 };
@@ -152,8 +172,8 @@ using FunctionKey = uint64_t;
 using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
 using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
 
-extern FwdRegistry FWD_FUNCS;
-extern BwdRegistry BWD_FUNCS;
+extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
+extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -238,4 +258,24 @@ struct BwdRegistrar{
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct FwdParallelRegistrar{
+    FwdParallelRegistrar(FwdFunction f){
+        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
+        PARALLEL_FWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct BwdParallelRegistrar{
+    BwdParallelRegistrar(BwdFunction f){
+        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
+        PARALLEL_BWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace layer_norm
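
The new FwdParallelRegistrar/BwdParallelRegistrar mirror the existing registrars: each per-hidden-size .cu translation unit defines a launch function and instantiates a file-scope registrar whose constructor inserts that launcher into PARALLEL_FWD_FUNCS/PARALLEL_BWD_FUNCS under a key derived from the type combination and hidden size, so ln_api.cpp can dispatch at runtime. A toy Python analogue of that registry shape (the real code does this with static C++ objects; names below are illustrative):

from typing import Callable, Dict, Tuple

# The tuple key stands in for Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
# the dict stands in for PARALLEL_FWD_FUNCS.
PARALLEL_FWD_FUNCS: Dict[Tuple[str, str, str, str, str, int], Callable[[int], None]] = {}

def register_parallel_fwd(wtype, itype, rtype, otype, ctype, hidden_size):
    # Plays the role of a file-scope FwdParallelRegistrar in ln_parallel_fwd_<size>.cu.
    def _register(launcher):
        PARALLEL_FWD_FUNCS[(wtype, itype, rtype, otype, ctype, hidden_size)] = launcher
        return launcher
    return _register

@register_parallel_fwd("bf16", "bf16", "bf16", "bf16", "fp32", 8192)
def launch_8192(rows: int) -> None:
    print(f"launch the 8192-wide kernel on {rows} rows")

# Runtime lookup, like get_parallel_fwd_launcher in ln_api.cpp:
PARALLEL_FWD_FUNCS[("bf16", "bf16", "bf16", "bf16", "fp32", 8192)](4)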

+ 408 - 32
csrc/layer_norm/ln_api.cpp

@@ -28,8 +28,8 @@ namespace layer_norm {
 
 // Create registries and provide runtime versions of config hash functions.
 
-FwdRegistry FWD_FUNCS;
-BwdRegistry BWD_FUNCS;
+FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
+BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -80,6 +80,28 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
+    auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
+    if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {
+        return iter->second;
+    } else {
+        TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
+    auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
+    if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {
+        return iter->second;
+    } else {
+        TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
                                            c10::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
@@ -105,8 +127,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     auto ctype = torch::kFloat32;
     auto mtype = torch::kUInt8;
 
-    TORCH_CHECK(x0.is_cuda())
-    TORCH_CHECK(gamma.is_cuda())
+    TORCH_CHECK(x0.is_cuda());
+    TORCH_CHECK(gamma.is_cuda());
 
     TORCH_CHECK(x0.is_contiguous());
     // c10::IntArrayRef does not own the storage, so we need to construct a vector.
@@ -120,25 +142,26 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     const int rows = sizes[0];
     const int cols = sizes[1];
     auto hidden_size = gamma.numel();
+    TORCH_CHECK(hidden_size == cols);
 
     if (beta_.has_value()) {
         auto beta = beta_.value();
         TORCH_CHECK(beta.dtype() == wtype);
-        TORCH_CHECK(beta.is_cuda())
+        TORCH_CHECK(beta.is_cuda());
         TORCH_CHECK(beta.is_contiguous());
-        TORCH_CHECK(gamma.sizes() == beta.sizes());
+        TORCH_CHECK(beta.sizes() == gamma.sizes());
     }
 
     if (residual_.has_value()) {
         auto residual = residual_.value();
-        TORCH_CHECK(residual.is_cuda())
+        TORCH_CHECK(residual.is_cuda());
         TORCH_CHECK(residual.is_contiguous());
         TORCH_CHECK(residual.sizes() == sizes);
     }
 
     if (rowscale_.has_value()) {
         auto rowscale = rowscale_.value();
-        TORCH_CHECK(rowscale.is_cuda())
+        TORCH_CHECK(rowscale.is_cuda());
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(rowscale.dtype() == itype);
@@ -146,7 +169,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 
     if (colscale_.has_value()) {
         auto colscale = colscale_.value();
-        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_cuda());
         TORCH_CHECK(colscale.is_contiguous());
         TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
         TORCH_CHECK(colscale.dtype() == wtype);
@@ -154,7 +177,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 
     if (x0_subset_.has_value()) {
         auto x0_subset = x0_subset_.value();
-        TORCH_CHECK(x0_subset.is_cuda())
+        TORCH_CHECK(x0_subset.is_cuda());
         TORCH_CHECK(x0_subset.is_contiguous());
         TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
@@ -167,9 +190,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
     }
 
-    TORCH_CHECK(hidden_size == cols);
-    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
-
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
     TORCH_CHECK(epsilon >= 0.f);
 
     // Otherwise the kernel will be launched from cuda:0 device
@@ -306,6 +327,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     auto cols = sizes[1];
     TORCH_CHECK(dz.dim() == 2);
     TORCH_CHECK(dz.size(1) == cols);
+    auto hidden_size = gamma.numel();
+    TORCH_CHECK(hidden_size == cols);
 
     // c10::IntArrayRef does not own the storage, so we need to construct a vector.
     // Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because
@@ -316,7 +339,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     if (dx_.has_value()) {
         auto dx = dx_.value();
         TORCH_CHECK(dx.dtype() == rtype);
-        TORCH_CHECK(dx.is_cuda())
+        TORCH_CHECK(dx.is_cuda());
         TORCH_CHECK(dx.is_contiguous());
         TORCH_CHECK(dx.sizes() == sizes);
     }
@@ -331,7 +354,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 
     if (rowscale_.has_value()) {
         auto rowscale = rowscale_.value();
-        TORCH_CHECK(rowscale.is_cuda())
+        TORCH_CHECK(rowscale.is_cuda());
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(rowscale.dtype() == itype);
@@ -339,14 +362,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 
     if (colscale_.has_value()) {
         auto colscale = colscale_.value();
-        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_cuda());
         TORCH_CHECK(colscale.is_contiguous());
         TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
         TORCH_CHECK(colscale.dtype() == wtype);
 
         TORCH_CHECK(x0_.has_value());
         auto x0 = x0_.value();
-        TORCH_CHECK(x0.is_cuda())
+        TORCH_CHECK(x0.is_cuda());
         TORCH_CHECK(x0.is_contiguous());
         TORCH_CHECK(x0.sizes() == x0_sizes);
         TORCH_CHECK(x0.dtype() == itype);
@@ -354,7 +377,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 
     if (x0_subset_.has_value()) {
         auto x0_subset = x0_subset_.value();
-        TORCH_CHECK(x0_subset.is_cuda())
+        TORCH_CHECK(x0_subset.is_cuda());
         TORCH_CHECK(x0_subset.is_contiguous());
         TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
@@ -367,9 +390,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
     }
 
-    auto hidden_size = gamma.numel();
-    TORCH_CHECK(hidden_size == cols);
-    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
 
     TORCH_CHECK(mu.numel() == rows);
     TORCH_CHECK(mu.sizes() == rsigma.sizes());
@@ -457,18 +478,373 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     }
     return result;
 }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
+    const at::Tensor &x0,      // Input: BxSxhidden_size
+    c10::optional<const at::Tensor> &x1_,      // Input: BxSxhidden_size
+    c10::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size
+    const at::Tensor &gamma0,   // hidden_size
+    c10::optional<const at::Tensor> &beta0_,   // hidden_size
+    c10::optional<const at::Tensor> &gamma1_,   // hidden_size
+    c10::optional<const at::Tensor> &beta1_,   // hidden_size
+    const float dropout_p,
+    const float epsilon,
+    c10::optional<at::Generator> gen_,
+    bool residual_in_fp32=false,
+    bool is_rms_norm=false
+) {
+    auto itype = x0.scalar_type();
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
+        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
+    auto wtype = gamma0.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    TORCH_CHECK(x0.is_cuda());
+    TORCH_CHECK(gamma0.is_cuda());
+
+    TORCH_CHECK(x0.is_contiguous());
+    const auto sizes = x0.sizes();
+    TORCH_CHECK(x0.dim() == 2);
+
+    const int rows = sizes[0];
+    const int cols = sizes[1];
+    auto hidden_size = gamma0.numel();
+    TORCH_CHECK(hidden_size == cols);
+
+    if (x1_.has_value()) {
+        auto x1 = x1_.value();
+        TORCH_CHECK(x1.is_cuda());
+        TORCH_CHECK(x1.is_contiguous());
+        TORCH_CHECK(x1.sizes() == sizes);
+    }
+
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda());
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
+    }
+
+    if (beta0_.has_value()) {
+        auto beta0 = beta0_.value();
+        TORCH_CHECK(beta0.dtype() == wtype);
+        TORCH_CHECK(beta0.is_cuda());
+        TORCH_CHECK(beta0.is_contiguous());
+        TORCH_CHECK(beta0.sizes() == gamma0.sizes());
+    }
+
+    if (gamma1_.has_value()) {
+        auto gamma1 = gamma1_.value();
+        TORCH_CHECK(gamma1.dtype() == wtype);
+        TORCH_CHECK(gamma1.is_cuda());
+        TORCH_CHECK(gamma1.is_contiguous());
+        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
+    }
+
+    if (beta1_.has_value()) {
+        auto beta1 = beta1_.value();
+        TORCH_CHECK(beta1.dtype() == wtype);
+        TORCH_CHECK(beta1.is_cuda());
+        TORCH_CHECK(beta1.is_contiguous());
+        TORCH_CHECK(beta1.sizes() == gamma0.sizes());
+    }
+
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
+    TORCH_CHECK(epsilon >= 0.f);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
+
+    auto opts = x0.options();
+
+    bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
+    at::Tensor x;
+    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
+    at::Tensor dmask0, dmask1;
+    if (dropout_p > 0.f) {
+        dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
+        if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
+    };
+    auto z0 = torch::empty(sizes, opts.dtype(otype));
+    at::Tensor z1;
+    if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }
+
+    auto mu = torch::empty({ rows }, opts.dtype(ctype));
+    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
+
+    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
+
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    // Request the kernel launcher.
+    auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
+
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
+    // Set the kernel runtime parameters.
+    layer_norm::FwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x0 = x0.data_ptr();
+    params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    params.x = save_x ? x.data_ptr() : nullptr;
+    params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
+    params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma0.data_ptr();
+    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
+    params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
+    params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
+    params.z = z0.data_ptr();
+    params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
+    params.epsilon = epsilon;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
+    params.is_rms_norm = is_rms_norm;
+
+    if (dropout_p > 0.f) {
+        // number of times random will be generated per thread, to offset philox counter in thc random
+        // state
+        int64_t counter_offset = 2 * launch_params.elts_per_thread;
+
+        // See Note [Acquire lock when using random generators]
+        {
+            std::lock_guard<std::mutex> lock(gen->mutex_);
+            params.philox_args = gen->philox_cuda_state(counter_offset);
+        }
+    }
+
+    if( launch_params.barrier_size > 0 ) {
+        auto options = x0.options();
+        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    // Launch the kernel.
+    launcher(launch_params, false);
+
+    return { z0, z1, x, dmask0, dmask1, mu, rsigma };
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
+    const at::Tensor &dz0,     // BxSxhidden_size
+    c10::optional<const at::Tensor> &dz1_,     // BxSxhidden_size
+    c10::optional<const at::Tensor> &dx_,     // BxSxhidden_size
+    const at::Tensor &x,      // BxSxhidden_size
+    c10::optional<const at::Tensor> &dmask0_,  // BxSxhidden_size
+    c10::optional<const at::Tensor> &dmask1_,  // BxSxhidden_size
+    const at::Tensor &mu,     // BxS, FP32!
+    const at::Tensor &rsigma, // BxS, FP32!
+    const at::Tensor &gamma0,   // hidden_size
+    c10::optional<const at::Tensor> &gamma1_,   // hidden_size
+    const float dropout_p,
+    const bool has_x1,
+    const bool has_residual,
+    bool is_rms_norm=false
+) {
+
+    auto itype = dz0.scalar_type();
+    auto rtype = x.scalar_type();
+    auto wtype = gamma0.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }
+
+    TORCH_CHECK(dz0.dtype() == otype);
+    TORCH_CHECK(dz0.dtype() == otype);
+    TORCH_CHECK(mu.dtype() == ctype);
+    TORCH_CHECK(rsigma.dtype() == ctype);
+
+    TORCH_CHECK(x.is_cuda());
+    TORCH_CHECK(dz0.is_cuda());
+    TORCH_CHECK(mu.is_cuda());
+    TORCH_CHECK(rsigma.is_cuda());
+    TORCH_CHECK(gamma0.is_cuda());
+
+    TORCH_CHECK(x.is_contiguous());
+    TORCH_CHECK(dz0.is_contiguous());
+
+    auto sizes = x.sizes();
+    TORCH_CHECK(sizes.size() == 2);
+    auto rows = sizes[0];
+    auto cols = sizes[1];
+    TORCH_CHECK(dz0.dim() == 2);
+    TORCH_CHECK(dz0.size(1) == cols);
+    auto hidden_size = gamma0.numel();
+    TORCH_CHECK(hidden_size == cols);
+
+    if (dz1_.has_value()) {
+        auto dz1 = dz1_.value();
+        TORCH_CHECK(dz1.dtype() == otype);
+        TORCH_CHECK(dz1.is_cuda());
+        TORCH_CHECK(dz1.is_contiguous());
+        TORCH_CHECK(dz1.sizes() == sizes);
+
+        TORCH_CHECK(gamma1_.has_value());
+        auto gamma1 = gamma1_.value();
+        TORCH_CHECK(gamma1.dtype() == wtype);
+        TORCH_CHECK(gamma1.is_cuda());
+        TORCH_CHECK(gamma1.is_contiguous());
+        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
+    }
+
+    if (dx_.has_value()) {
+        auto dx = dx_.value();
+        TORCH_CHECK(dx.dtype() == rtype);
+        TORCH_CHECK(dx.is_cuda());
+        TORCH_CHECK(dx.is_contiguous());
+        TORCH_CHECK(dx.sizes() == sizes);
+    }
+
+    if (dmask0_.has_value()) {
+        auto dmask0 = dmask0_.value();
+        TORCH_CHECK(dmask0.dtype() == mtype);
+        TORCH_CHECK(dmask0.is_cuda());
+        TORCH_CHECK(dmask0.is_contiguous());
+        TORCH_CHECK(dmask0.sizes() == sizes);
+
+        if (has_x1) {
+            TORCH_CHECK(dmask1_.has_value());
+            auto dmask1 = dmask1_.value();
+            TORCH_CHECK(dmask1.dtype() == mtype);
+            TORCH_CHECK(dmask1.is_cuda());
+            TORCH_CHECK(dmask1.is_contiguous());
+            TORCH_CHECK(dmask1.sizes() == sizes);
+        }
+    }
+
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
+
+    TORCH_CHECK(mu.numel() == rows);
+    TORCH_CHECK(mu.sizes() == rsigma.sizes());
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)dz0.get_device()};
+
+    auto opts = x.options();
+
+    auto dx0 = torch::empty(sizes, opts.dtype(itype));
+    at::Tensor dx1;
+    if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
+    auto dgamma0 = torch::empty_like(gamma0);
+    auto dbeta0 = torch::empty_like(gamma0);
+    at::Tensor dgamma1, dbeta1;
+    if (gamma1_.has_value()) {
+        dgamma1 = torch::empty_like(gamma0);
+        dbeta1 = torch::empty_like(gamma0);
+    }
+
+    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
+
+    launcher(launch_params, true);
+
+    auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor dgamma1_part, dbeta1_part;
+    if (gamma1_.has_value()) {
+        dgamma1_part = torch::zeros_like(dgamma0_part);
+        dbeta1_part = torch::zeros_like(dbeta0_part);
+    }
+    at::Tensor workspace, barrier;
+
+    layer_norm::BwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x = x.data_ptr();
+    params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
+    params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma0.data_ptr();
+    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
+    params.dz = dz0.data_ptr();
+    params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
+    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
+    params.dx0 = dx0.data_ptr();
+    params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
+    params.dbeta = dbeta0.data_ptr();
+    params.dgamma = dgamma0.data_ptr();
+    params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
+    params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
+    params.dbeta_part = dbeta0_part.data_ptr();
+    params.dgamma_part = dgamma0_part.data_ptr();
+    params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
+    params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
+    params.is_rms_norm = is_rms_norm;
+
+    if( launch_params.barrier_size > 0 ) {
+        // TODO Any way to avoid this?
+        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    launcher(launch_params, false);
+
+    std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.doc() = "CUDA DropoutAddLayerNorm";
-  m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
-        py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
-        py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
-        py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
-        py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
-  m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
-        py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
-        py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
-        py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
-        py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
+    m.doc() = "CUDA DropoutAddLayerNorm";
+    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
+          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
+          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
+          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
+          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
+          py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
+          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
+          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd, "Run Dropout + Add + LayerNorm parallel residual forward kernel",
+          py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
+          py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd, "Run Dropout + Add + LayerNorm parallel residual backward kernel",
+          py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
+          py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
+          py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
 }
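
One practical detail from the two new entry points above: arbitrary hidden sizes (any multiple of 8 up to 8192) are mapped onto the fixed set of compiled kernels by rounding up, using multiples of 256 up to 1536, 512 up to 3072, and 1024 beyond that, before the registry lookup. A small Python restatement of that selection logic (mirrors round_multiple in ln_api.cpp; illustrative only):

def round_multiple(x: int, m: int) -> int:
    return (x + m - 1) // m * m

def launcher_hidden_size(hidden_size: int) -> int:
    # Same bucketing as dropout_add_ln_parallel_residual_fwd/bwd.
    multiple = 256 if hidden_size <= 1536 else (512 if hidden_size <= 3072 else 1024)
    return round_multiple(hidden_size, multiple)

assert launcher_hidden_size(2304) == 2560   # served by the 2560-wide kernel
assert launcher_hidden_size(8000) == 8192   # served by the 8192-wide kernel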

+ 15 - 0
csrc/layer_norm/ln_bwd_7168.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
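
A note on reading these registration lines (an interpretation based on Apex FastLayerNorm conventions, not stated in this commit): WARPS_M is roughly the number of rows a CTA handles, WARPS_N the number of warps cooperating across one row, and BYTES_PER_LDG the vector load width. A hypothetical back-of-envelope for the fp16 7168 line above, under those assumptions:

# Hypothetical decoding of REGISTER_BWD_LAUNCHER(7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4)
ctas_per_row, warps_m, warps_n, bytes_per_ldg = 1, 1, 8, 8
threads_per_cta = 32 * warps_m * warps_n                    # 256 threads per CTA
elems_per_ldg = bytes_per_ldg // 2                          # 4 fp16 elements per vectorized load
ldgs_per_thread = 7168 // (32 * warps_n * elems_per_ldg)    # 7 loads to cover one 7168-wide row
print(threads_per_cta, elems_per_ldg, ldgs_per_thread)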

+ 15 - 0
csrc/layer_norm/ln_bwd_8192.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_fwd_7168.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_8192.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_1024.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_1280.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_1536.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_2048.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_256.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_2560.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_3072.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

+ 17 - 0
csrc/layer_norm/ln_parallel_bwd_4096.cu

@@ -0,0 +1,17 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+// Use 8 warps otherwise there's a lot of register spilling
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_512.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 17 - 0
csrc/layer_norm/ln_parallel_bwd_5120.cu

@@ -0,0 +1,17 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+// Use 8 warps otherwise there's a lot of register spilling
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_6144.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_7168.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_768.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_bwd_8192.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_1024.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_1280.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_1536.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_2048.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_256.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_2560.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_3072.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_4096.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_512.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_5120.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_6144.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_7168.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_768.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_parallel_fwd_8192.cu

@@ -0,0 +1,15 @@
+#include "ln_parallel_residual_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);

+ 540 - 0
csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh

@@ -0,0 +1,540 @@
+#pragma once
+
+#include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "static_switch.h"
+#include "ln_bwd_kernels.cuh"
+
+namespace layer_norm {
+
+template<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>
+__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
+void ln_parallel_residual_bwd_kernel(layer_norm::BwdParams params) {
+
+    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
+    enum { WARPS_M = Ktraits::WARPS_M };
+    enum { WARPS_N = Ktraits::WARPS_N };
+    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
+    enum { COLS = Ktraits::COLS };
+    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
+    enum { LDGS = Ktraits::LDGS };
+    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
+    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
+    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
+
+    using input_t = typename Ktraits::input_t;
+    using compute_t = typename Ktraits::compute_t;
+    using index_t = typename Ktraits::index_t;
+    using mask_t = typename Ktraits::mask_t;
+    using Ivec = typename Ktraits::Ivec;
+    using Rvec = typename Ktraits::Rvec;
+    using Ovec = typename Ktraits::Ovec;
+    using Wvec = typename Ktraits::Wvec;
+    using Cvec = typename Ktraits::Cvec;
+    using Mvec = typename Ktraits::Mvec;
+    using Reducer = typename Ktraits::Reducer;
+    using reduce_t = typename Reducer::Type;
+
+    extern __shared__ char smem_[];
+
+    const bool has_residual = params.dresidual != nullptr;
+    const bool has_x1 = params.dx1 != nullptr;
+    const bool prenorm = params.dx != nullptr;
+
+    const index_t tidx = threadIdx.x;
+    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
+    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
+    const index_t lane = tidx % THREADS_PER_WARP;
+    const index_t warp = tidx / THREADS_PER_WARP;
+    const index_t warp_m = warp / Ktraits::WARPS_N;
+    const index_t warp_n = warp % Ktraits::WARPS_N;
+    const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
+
+    const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
+    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
+
+    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
+
+    Cvec dz0y_sum[LDGS];
+    Cvec dz0_sum[LDGS];
+    Cvec dz1y_sum[LDGS];
+    Cvec dz1_sum[LDGS];
+
+    memset(dz0y_sum, 0, sizeof(dz0y_sum));
+    memset(dz0_sum, 0, sizeof(dz0_sum));
+    if (!Tied_norm) {
+        memset(dz1y_sum, 0, sizeof(dz1y_sum));
+        memset(dz1_sum, 0, sizeof(dz1_sum));
+    }
+
+    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
+    char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
+
+    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
+
+    Sum<reduce_t> sum;
+
+    const index_t num_valid_ldgs =
+        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
+
+    Wvec gamma0[LDGS];
+    Wvec gamma1[LDGS];
+    index_t idx = c;
+    #pragma unroll
+    for( int it = 0; it < LDGS; it++ ) {
+        if (Is_even_cols || (it < num_valid_ldgs)) {
+            gamma0[it].load_from(params.gamma, idx);
+            if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); }
+            idx += Ktraits::VEC_COLS_PER_LDG;
+        }
+    }
+    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
+    // last blocks with syncthreads!
+    // grid stride over rows
+    #pragma unroll 1
+    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
+        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
+        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
+        Mvec dmask0[LDGS], dmask1[LDGS];
+        Rvec dx[LDGS];
+        compute_t dy[LDGS * NUM_ELTS];
+        compute_t y[LDGS * NUM_ELTS];
+        compute_t mdy_local = 0.f;
+        compute_t mdyy_local = 0.f;
+        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Rvec x;
+                Ovec dz0, dz1;
+                dz0.load_from(params.dz, idx);
+                if (!Tied_norm) { dz1.load_from(params.dz1, idx); }
+                if (prenorm) { dx[it].load_from(params.dx, idx); }
+                x.load_from(params.x, idx);
+                if (Is_dropout) {
+                    dmask0[it].load_from(params.dmask, idx);
+                    if (has_x1) { dmask1[it].load_from(params.dmask1, idx); }
+                }
+                idx += Ktraits::VEC_COLS_PER_LDG;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t x_tmp = x.data.elt[jt];
+                    compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
+                    compute_t dy_tmp = compute_t(gamma0[it].data.elt[jt]) * compute_t(dz0.data.elt[jt]);
+                    if (!Tied_norm) {
+                        dy_tmp += compute_t(gamma1[it].data.elt[jt]) * compute_t(dz1.data.elt[jt]);
+                    }
+                    compute_t dz0_tmp = dz0.data.elt[jt];
+                    compute_t dz1_tmp;
+                    if (!Tied_norm) { dz1_tmp = dz1.data.elt[jt]; }
+
+                    mdy_local += dy_tmp;
+                    mdyy_local += dy_tmp * y_tmp;
+
+                    dy[it * NUM_ELTS + jt] = dy_tmp;
+                    y[it * NUM_ELTS + jt] = y_tmp;
+
+                    dz0y_sum[it].data.elt[jt] += dz0_tmp * y_tmp;
+                    dz0_sum[it].data.elt[jt] += dz0_tmp;
+                    if (!Tied_norm) {
+                        dz1y_sum[it].data.elt[jt] += dz1_tmp * y_tmp;
+                        dz1_sum[it].data.elt[jt] += dz1_tmp;
+                    }
+                }
+            }
+        }
+
+        reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
+        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
+        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
+
+        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ivec dx0, dx1;
+                Rvec dresidual;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t dx_tmp_res;
+                    compute_t dy_tmp = dy[it * NUM_ELTS + jt];
+                    compute_t y_tmp = y[it * NUM_ELTS + jt];
+                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
+                    dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
+                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
+                    if (Is_dropout) {
+                        dx0.data.elt[jt] = dmask0[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f;
+                        if (has_x1) { dx1.data.elt[jt] = dmask1[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; }
+                    } else {
+                        dx0.data.elt[jt] = dx_tmp_res;
+                        if (has_x1) { dx1.data.elt[jt] = dx_tmp_res; }
+                    }
+                }
+                if (has_residual) { dresidual.store_to(params.dresidual, idx); }
+                dx0.store_to(params.dx0, idx);
+                if (has_x1) { dx1.store_to(params.dx1, idx); }
+                idx += Ktraits::VEC_COLS_PER_LDG;
+            }
+        }
+
+    }  // end: grid stride loop
+
+    if( WARPS_M == 1 ) {
+        idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                dz0_sum[it].store_to(params.dbeta_part, idx);
+                dz0y_sum[it].store_to(params.dgamma_part, idx);
+                if (!Tied_norm) {
+                    dz1_sum[it].store_to(params.dbeta1_part, idx);
+                    dz1y_sum[it].store_to(params.dgamma1_part, idx);
+                }
+                idx += Ktraits::VEC_COLS_PER_LDG;
+            }
+        }
+    } else {
+        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
+        // Finalize reduction of part dgamma and dbeta for this CTA
+        // by reducing over the rows held across the WARPS_M warps
+
+        // Assumption: blockSize divides hidden size.
+        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
+        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
+
+        idx = warp_m * Ktraits::VEC_COLS + tid_r;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            dz0_sum[it].store_to(smem_wgrad, idx);
+            idx += THREADS_PER_ROW;
+        }
+        __syncthreads();
+        compute_t cta_dz0_sum[NUM_RES];
+        memset(cta_dz0_sum, 0, sizeof(compute_t) * NUM_RES);
+        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+            for( int jt = 0; jt < NUM_RES; jt++ ) {
+                cta_dz0_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+            }
+        }
+        __syncthreads();
+
+        idx = warp_m * Ktraits::VEC_COLS + tid_r;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            dz0y_sum[it].store_to(smem_wgrad, idx);
+            idx += THREADS_PER_ROW;
+        }
+        __syncthreads();
+        compute_t cta_dz0y_sum[NUM_RES];
+        memset(cta_dz0y_sum, 0, sizeof(compute_t) * NUM_RES);
+        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+            for( int jt = 0; jt < NUM_RES; jt++ ) {
+                cta_dz0y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+            }
+        }
+
+        compute_t cta_dz1_sum[NUM_RES], cta_dz1y_sum[NUM_RES];
+        if (!Tied_norm) {
+            __syncthreads();
+            idx = warp_m * Ktraits::VEC_COLS + tid_r;
+            #pragma unroll
+            for( int it = 0; it < LDGS; it++ ) {
+                dz1_sum[it].store_to(smem_wgrad, idx);
+                idx += THREADS_PER_ROW;
+            }
+            __syncthreads();
+            memset(cta_dz1_sum, 0, sizeof(compute_t) * NUM_RES);
+            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+                for( int jt = 0; jt < NUM_RES; jt++ ) {
+                    cta_dz1_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+                }
+            }
+            __syncthreads();
+            idx = warp_m * Ktraits::VEC_COLS + tid_r;
+            #pragma unroll
+            for( int it = 0; it < LDGS; it++ ) {
+                dz1y_sum[it].store_to(smem_wgrad, idx);
+                idx += THREADS_PER_ROW;
+            }
+            __syncthreads();
+            memset(cta_dz1y_sum, 0, sizeof(compute_t) * NUM_RES);
+            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+                for( int jt = 0; jt < NUM_RES; jt++ ) {
+                    cta_dz1y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+                }
+            }
+        }
+
+        const index_t num_valid_writes
+            = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
+        compute_t *dgamma0_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
+        compute_t *dbeta0_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
+        compute_t *dgamma1_part = !Tied_norm ? static_cast<compute_t *>(params.dgamma1_part) + bidm * params.cols + tidx : nullptr;
+        compute_t *dbeta1_part = !Tied_norm ? static_cast<compute_t *>(params.dbeta1_part) + bidm * params.cols + tidx : nullptr;
+        for( int jt = 0; jt < NUM_RES; jt++ ) {
+            if (Is_even_cols || (jt < num_valid_writes)) {
+                *dgamma0_part = cta_dz0y_sum[jt];
+                dgamma0_part += Ktraits::THREADS_PER_CTA;
+                *dbeta0_part = cta_dz0_sum[jt];
+                dbeta0_part += Ktraits::THREADS_PER_CTA;
+                if (!Tied_norm) {
+                    *dgamma1_part = cta_dz1y_sum[jt];
+                    dgamma1_part += Ktraits::THREADS_PER_CTA;
+                    *dbeta1_part = cta_dz1_sum[jt];
+                    dbeta1_part += Ktraits::THREADS_PER_CTA;
+                }
+            }
+        }
+
+    }
+}
+
+template<typename Kernel_traits, bool Is_even_cols>
+__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
+void ln_parallel_residual_bwd_finalize_kernel(BwdParams params)
+{
+
+    using compute_t = typename Kernel_traits::compute_t;
+    using weight_t = typename Kernel_traits::weight_t;
+    using index_t = typename Kernel_traits::index_t;
+    using Reducer = typename Kernel_traits::Reducer;
+    using reduce_t = typename Reducer::Type;
+
+    Sum<reduce_t> sum;
+    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
+    enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
+
+    // Multiplying by 2 since we have both gamma0 and gamma1
+    __shared__ char smem_[2 * Kernel_traits::SMEM_BYTES_PER_CTA];
+
+    constexpr uint32_t bidm = 0;
+
+    const uint32_t bidn = blockIdx.x;
+    const uint32_t tidx = threadIdx.x;
+    const uint32_t warp = tidx / THREADS_PER_WARP;
+    const uint32_t lane = tidx % THREADS_PER_WARP;
+
+    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
+
+    const uint32_t c = bidn * THREADS_PER_WARP + lane;
+    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
+    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
+    for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
+        // Each thread sums over NUM_ELT columns.
+        Vec<compute_t, NUM_ELT> dbeta0_local, dgamma0_local, dbeta1_local, dgamma1_local;
+        memset(&dgamma0_local, 0, sizeof(dgamma0_local));
+        memset(&dbeta0_local, 0, sizeof(dbeta0_local));
+        memset(&dgamma1_local, 0, sizeof(dgamma1_local));
+        memset(&dbeta1_local, 0, sizeof(dbeta1_local));
+        if (Is_even_cols || col < params.cols) {
+            for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
+                index_t idx = row * params.cols + col;
+
+                Vec<compute_t, NUM_ELT> dbeta0_part, dgamma0_part, dbeta1_part, dgamma1_part;
+                dbeta0_part.load_from(params.dbeta_part, idx);
+                dgamma0_part.load_from(params.dgamma_part, idx);
+                dbeta1_part.load_from(params.dbeta1_part, idx);
+                dgamma1_part.load_from(params.dgamma1_part, idx);
+                #pragma unroll
+                for( int it = 0; it < NUM_ELT; it++ ) {
+                    dgamma0_local.data.elt[it] += dgamma0_part.data.elt[it];
+                    dbeta0_local.data.elt[it] += dbeta0_part.data.elt[it];
+                    dgamma1_local.data.elt[it] += dgamma1_part.data.elt[it];
+                    dbeta1_local.data.elt[it] += dbeta1_part.data.elt[it];
+                }
+            }
+        }
+        void * smem_gamma0 = smem_;
+        void * smem_beta0 = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_gamma1 = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta1 = &smem_[3 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+
+        const int write_row = warp;
+        const int write_col = lane ^ write_row;
+        const int write_idx = write_row * THREADS_PER_WARP + write_col;
+
+        dgamma0_local.store_to(smem_gamma0, write_idx);
+        dbeta0_local.store_to(smem_beta0, write_idx);
+        dgamma1_local.store_to(smem_gamma1, write_idx);
+        dbeta1_local.store_to(smem_beta1, write_idx);
+
+        __syncthreads();
+
+        // It would probably be safe to reuse the first row of smem_beta0 and smem_gamma0
+        void * smem_gamma0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_gamma1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_beta1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 3 * Kernel_traits::SMEM_BYTES_OUTPUT];
+
+        // More than one iter iff ROWS_PER_CTA < 32.
+        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
+            const int read_row = lane;
+            const int read_col = w ^ read_row;
+            const int read_idx = read_row * THREADS_PER_WARP + read_col;
+
+            memset(&dbeta0_local, 0, sizeof(dbeta0_local));
+            memset(&dgamma0_local, 0, sizeof(dgamma0_local));
+            memset(&dbeta1_local, 0, sizeof(dbeta1_local));
+            memset(&dgamma1_local, 0, sizeof(dgamma1_local));
+
+            // Load beta and gamma transposed
+            if(read_row < Kernel_traits::ROWS_PER_CTA){
+                dbeta0_local.load_from(smem_beta0, read_idx);
+                dgamma0_local.load_from(smem_gamma0, read_idx);
+                dbeta1_local.load_from(smem_beta1, read_idx);
+                dgamma1_local.load_from(smem_gamma1, read_idx);
+            }
+
+            // Call reducer on the loaded value(s) and convert.
+            #pragma unroll
+            for( int it = 0; it < NUM_ELT; it++ ) {
+                compute_t b0_i = dbeta0_local.data.elt[it];
+                compute_t g0_i = dgamma0_local.data.elt[it];
+                compute_t b1_i = dbeta1_local.data.elt[it];
+                compute_t g1_i = dgamma1_local.data.elt[it];
+                b0_i = reducer.allreduce(b0_i, sum);
+                g0_i = reducer.allreduce(g0_i, sum);
+                b1_i = reducer.allreduce(b1_i, sum);
+                g1_i = reducer.allreduce(g1_i, sum);
+
+                dgamma0_local.data.elt[it] = g0_i;
+                dbeta0_local.data.elt[it] = b0_i;
+                dgamma1_local.data.elt[it] = g1_i;
+                dbeta1_local.data.elt[it] = b1_i;
+            }
+
+            // Leader stores the result at the current column.
+            if(lane == 0){
+                dgamma0_local.store_to(smem_gamma0_out, w);
+                dbeta0_local.store_to(smem_beta0_out, w);
+                dgamma1_local.store_to(smem_gamma1_out, w);
+                dbeta1_local.store_to(smem_beta1_out, w);
+            }
+
+        }
+
+        // All writes done.
+        __syncthreads();
+
+        // Pack and store: 2-wide stores with half the threads.
+        if (Is_even_cols || col_out * 2 < params.cols) {
+            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
+
+                using src_t = typename TypeToVec2<compute_t>::Type;
+                using dst_t = typename TypeToVec2<weight_t>::Type;
+                Vec<src_t, NUM_ELT> dbeta0_vec2, dgamma0_vec2, dbeta1_vec2, dgamma1_vec2;
+                Vec<dst_t, NUM_ELT> dbeta0_out2, dgamma0_out2, dbeta1_out2, dgamma1_out2;
+
+                dgamma0_vec2.load_from(smem_gamma0_out, lane);
+                dbeta0_vec2.load_from(smem_beta0_out, lane);
+                dgamma1_vec2.load_from(smem_gamma1_out, lane);
+                dbeta1_vec2.load_from(smem_beta1_out, lane);
+                #pragma unroll
+                for( int it = 0; it < NUM_ELT; it++ ) {
+                    dgamma0_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma0_vec2.data.elt[it]);
+                    dbeta0_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta0_vec2.data.elt[it]);
+                    dgamma1_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma1_vec2.data.elt[it]);
+                    dbeta1_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta1_vec2.data.elt[it]);
+                }
+                dgamma0_out2.store_to(params.dgamma, col_out);
+                dbeta0_out2.store_to(params.dbeta, col_out);
+                dgamma1_out2.store_to(params.dgamma1, col_out);
+                dbeta1_out2.store_to(params.dbeta1, col_out);
+            }
+        }
+    }
+}
+
+}  // namespace layer_norm
+
+using namespace layer_norm;
+
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE,
+    int CTAS_PER_ROW,
+    int WARPS_M,
+    int WARPS_N,
+    int BYTES_PER_LDG_MAIN,
+    int BYTES_PER_LDG_FINAL
+>
+void launch_parallel_residual_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
+
+    using Kernel_traits = Kernel_traits<weight_t,
+                                        input_t,
+                                        residual_t,
+                                        output_t,
+                                        compute_t,
+                                        index_t,
+                                        HIDDEN_SIZE,
+                                        CTAS_PER_ROW,
+                                        WARPS_M,
+                                        WARPS_N,
+                                        BYTES_PER_LDG_MAIN
+                                        >;
+    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
+    bool tied_norm = launch_params.params.gamma1 == nullptr;
+    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(tied_norm, TiedNormConst, [&] {
+            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                auto kernel = &ln_parallel_residual_bwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;
+                if( configure_params ) {
+                    int ctas_per_sm;
+                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
+                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                    launch_params.barrier_size = 0;
+                    launch_params.workspace_bytes = 0;
+                    if(Kernel_traits::CTAS_PER_ROW > 1) {
+                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                      * Kernel_traits::WARPS_M
+                                                      * Kernel_traits::CTAS_PER_ROW
+                                                      * sizeof(typename Kernel_traits::reduce_t)
+                                                      * 2;
+                    }
+                    return;
+                }
+
+                if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
+                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+                }
+                auto stream = launch_params.stream;
+                auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
+                } else {
+                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                    dim3 block(Kernel_traits::THREADS_PER_CTA);
+                    void *params_ = (void *)&launch_params.params;
+                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
+                }
+
+                using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
+                                                                          weight_t,
+                                                                          input_t,
+                                                                          residual_t,
+                                                                          output_t,
+                                                                          compute_t,
+                                                                          index_t,
+                                                                          /*HasColscaleConst=*/false,
+                                                                          32 * 32,  // THREADS_PER_CTA
+                                                                          BYTES_PER_LDG_FINAL>;
+
+                auto kernel_f = !TiedNormConst
+                    ? &layer_norm::ln_parallel_residual_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>
+                    : &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, /*HasColscaleConst=*/false, IsEvenColsConst>;
+                kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+
+            });
+        });
+    });
+}
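
For reference, the per-row math that ln_parallel_residual_bwd_kernel implements above reduces to the following PyTorch sketch (untied, non-RMS case, dropout on both branches). This is a hedged reference only: tensor names mirror the kernel's parameters, but the function and its signature are illustrative and not part of the extension's API.

    import torch

    def parallel_residual_ln_bwd_reference(dz0, dz1, dx_prenorm, x, mu, rs,
                                            gamma0, gamma1, dmask0, dmask1,
                                            dropout_scale, has_residual=True):
        """Per-row backward for the untied, non-RMS case (reference sketch)."""
        y = rs[:, None] * (x - mu[:, None])            # normalized input, recomputed from saved mu/rs
        dy = gamma0 * dz0 + gamma1 * dz1               # both norms share the same normalized y
        mdy = dy.mean(dim=-1, keepdim=True)            # mdy_local after the allreduce
        mdyy = (dy * y).mean(dim=-1, keepdim=True)     # mdyy_local after the allreduce
        dx = rs[:, None] * (dy - (mdyy * y + mdy))     # grad w.r.t. x = dropout(x0) + dropout(x1) + residual
        if dx_prenorm is not None:                     # prenorm: add the gradient flowing around the norm
            dx = dx + dx_prenorm
        dresidual = dx if has_residual else None
        dx0 = torch.where(dmask0, dx * dropout_scale, torch.zeros_like(dx))
        dx1 = torch.where(dmask1, dx * dropout_scale, torch.zeros_like(dx))
        dgamma0, dbeta0 = (dz0 * y).sum(dim=0), dz0.sum(dim=0)   # what dgamma_part/dbeta_part accumulate
        dgamma1, dbeta1 = (dz1 * y).sum(dim=0), dz1.sum(dim=0)
        return dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1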

+ 281 - 0
csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh

@@ -0,0 +1,281 @@
+#pragma once
+
+#ifdef OLD_GENERATOR_PATH
+#include <ATen/CUDAGeneratorImpl.h>
+#else
+#include <ATen/cuda/CUDAGeneratorImpl.h>
+#endif
+
+#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack
+#include <curand_kernel.h>
+
+#include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "static_switch.h"
+
+namespace layer_norm {
+
+template<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>
+__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
+void ln_parallel_residual_fwd_kernel(FwdParams params) {
+
+    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
+    enum { WARPS_N = Ktraits::WARPS_N };
+    enum { WARPS_M = Ktraits::WARPS_M };
+    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
+    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
+    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
+    enum { LDGS = Ktraits::LDGS };
+    enum { NUM_ELTS = Ktraits::NUM_ELTS };
+    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
+
+    using input_t = typename Ktraits::input_t;
+    using residual_t = typename Ktraits::residual_t;
+    using output_t = typename Ktraits::output_t;
+    using index_t = typename Ktraits::index_t;
+    using compute_t = typename Ktraits::compute_t;
+    using mask_t = typename Ktraits::mask_t;
+    using Ivec = typename Ktraits::Ivec;
+    using Rvec = typename Ktraits::Rvec;
+    using Ovec = typename Ktraits::Ovec;
+    using Wvec = typename Ktraits::Wvec;
+    using Cvec = typename Ktraits::Cvec;
+    using Mvec = typename Ktraits::Mvec;
+
+    using Stats = typename Ktraits::Stats;
+    using stats_t = typename Stats::stats_t;
+
+    const bool has_residual = params.residual != nullptr;
+    const bool has_x1 = params.x1 != nullptr;
+    const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same<input_t, residual_t>::value);
+
+    extern __shared__ char smem_[];
+
+    const index_t tidx = threadIdx.x;
+    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
+    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
+    const index_t lane = tidx % THREADS_PER_WARP;
+    const index_t warp = tidx / THREADS_PER_WARP;
+    const index_t warp_m = warp / WARPS_N;
+    const index_t warp_n = warp % WARPS_N;
+
+    const index_t r = bidm * ROWS_PER_CTA + warp_m;
+    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
+
+    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
+
+    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
+    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
+
+    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
+    curandStatePhilox4_32_10_t state;
+    if (Is_dropout) {
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
+        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
+    }
+
+    const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
+
+    Wvec gamma0[LDGS];
+    Wvec beta0[LDGS];
+    Wvec gamma1[LDGS];
+    Wvec beta1[LDGS];
+    index_t idx = c;
+    #pragma unroll
+    for( int it = 0; it < LDGS; it++ ) {
+        if (Is_even_cols || (it < num_valid_ldgs)) {
+            gamma0[it].load_from(params.gamma, idx);
+            if (params.beta != nullptr) {
+                beta0[it].load_from(params.beta, idx);
+            } else {
+                beta0[it].zero_();
+            }
+            if (!Tied_norm) {
+                gamma1[it].load_from(params.gamma1, idx);
+                if (params.beta1 != nullptr) {
+                    beta1[it].load_from(params.beta1, idx);
+                } else {
+                    beta1[it].zero_();
+                }
+            }
+            idx += VEC_COLS_PER_LDG;
+        }
+    }
+
+    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
+        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        compute_t xf[LDGS * NUM_ELTS];
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ivec x0;
+                Ivec x1;
+                Rvec residual;
+                Rvec x;
+                Mvec dmask0;
+                Mvec dmask1;
+                x0.load_from(params.x0, idx);
+                if (has_x1) { x1.load_from(params.x1, idx); }
+                if (has_residual) { residual.load_from(params.residual, idx); }
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
+                    // the more efficient curand_uniform4.
+                    compute_t x_ij;
+                    mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
+                    if (Is_dropout) { dmask0.data.elt[jt] = keep0; }
+                    compute_t x0_ij = compute_t(x0.data.elt[jt]);
+                    x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
+                    if (has_x1) {
+                        mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
+                        if (Is_dropout) { dmask1.data.elt[jt] = keep1; }
+                        compute_t x1_ij = compute_t(x1.data.elt[jt]);
+                        x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f;
+                        x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij;
+                    } else {
+                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
+                    }
+                    if (save_x) { x.data.elt[jt] = x_ij; }
+                    xf[it * NUM_ELTS + jt] = x_ij;
+                }
+                if (save_x) { x.store_to(params.x, idx); }
+                if (Is_dropout) {
+                    dmask0.store_to(params.dmask, idx);
+                    if (has_x1) { dmask1.store_to(params.dmask1, idx); }
+                }
+                idx += VEC_COLS_PER_LDG;
+            }
+        }
+
+        static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
+        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
+        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
+        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
+        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+            // Need to convert to int, otherwise the subtraction will wrap around.
+            const index_t valid_partial_vecs_in_warp =
+                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
+                        int(THREADS_PER_WARP));
+            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
+        };
+        stats_t s = stats.template compute<Is_even_cols>(
+            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
+        );
+
+        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
+        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
+
+        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
+            mu_ptr[row] = mu;
+        }
+
+        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
+
+        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
+            rs_ptr[row] = rs;
+        }
+
+        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ovec z0;
+                Ovec z1;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
+                    compute_t g0_ij = gamma0[it].data.elt[jt];
+                    compute_t b0_ij = beta0[it].data.elt[jt];
+                    z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij);
+                    if (!Tied_norm) {
+                        compute_t g1_ij = gamma1[it].data.elt[jt];
+                        compute_t b1_ij = beta1[it].data.elt[jt];
+                        z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij);
+                    }
+                }
+                z0.store_to(params.z, idx);
+                if (!Tied_norm) { z1.store_to(params.z1, idx); }
+                idx += VEC_COLS_PER_LDG;
+            }
+        }
+
+    }
+}
+
+}  // namespace layer_norm
+
+using namespace layer_norm;
+
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE,
+    int CTAS_PER_ROW,
+    int WARPS_M,
+    int WARPS_N,
+    int BYTES_PER_LDG
+>
+void launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
+
+    using Kernel_traits = Kernel_traits<weight_t,
+                                        input_t,
+                                        residual_t,
+                                        output_t,
+                                        compute_t,
+                                        index_t,
+                                        HIDDEN_SIZE,
+                                        CTAS_PER_ROW,
+                                        WARPS_M,
+                                        WARPS_N,
+                                        BYTES_PER_LDG
+                                        >;
+    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+    bool tied_norm = launch_params.params.gamma1 == nullptr;
+    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
+        BOOL_SWITCH(tied_norm, TiedNormConst, [&] {
+            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                auto kernel = &ln_parallel_residual_fwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;
+                if( configure_params ) {
+                    int ctas_per_sm;
+                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
+                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
+                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
+                    launch_params.barrier_size = 0;
+                    launch_params.workspace_bytes = 0;
+                    if(Kernel_traits::CTAS_PER_ROW > 1) {
+                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                      * Kernel_traits::WARPS_M
+                                                      * Kernel_traits::CTAS_PER_ROW
+                                                      * sizeof(typename Kernel_traits::Stats::stats_t)
+                                                      * 2;
+                    }
+                    return;
+                }
+
+                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
+                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+                }
+                auto stream = launch_params.stream;
+                auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
+                } else {
+                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                    dim3 block(Kernel_traits::THREADS_PER_CTA);
+                    void *params_ = (void *)&launch_params.params;
+                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
+                }
+            });
+        });
+    });
+}
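
For reference, the forward kernel above applies dropout to both branches, adds the shared residual stream, computes a single set of statistics per row, and then applies one or two affine transforms. A minimal PyTorch sketch of that math for the untied, non-RMS case (the function and its argument names are illustrative, not the extension's API; p is the dropout probability):

    import torch

    def parallel_residual_ln_fwd_reference(x0, x1, residual, gamma0, beta0, gamma1, beta1,
                                            p, eps, training=True):
        """Forward reference for the untied, non-RMS case (sketch only)."""
        scale = 1.0 / (1.0 - p) if (training and p > 0) else 1.0

        def keep_mask(t):
            if training and p > 0:
                return torch.rand_like(t, dtype=torch.float32) >= p
            return torch.ones_like(t, dtype=torch.bool)

        dmask0 = keep_mask(x0)                                   # plays the role of params.dmask
        x = torch.where(dmask0, x0 * scale, torch.zeros_like(x0))
        if x1 is not None:
            dmask1 = keep_mask(x1)                               # plays the role of params.dmask1
            x = x + torch.where(dmask1, x1 * scale, torch.zeros_like(x1))
        if residual is not None:
            x = x + residual                                     # saved as params.x when needed
        mu = x.mean(dim=-1, keepdim=True)                        # saved as params.mu
        rs = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)  # saved as params.rs
        y = (x - mu) * rs
        z0 = gamma0 * y + (beta0 if beta0 is not None else 0.0)  # first norm (params.z)
        if gamma1 is not None:
            z1 = gamma1 * y + (beta1 if beta1 is not None else 0.0)  # second norm (params.z1)
        else:
            z1 = None                                            # tied norm: only z0 is produced
        return z0, z1, x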

+ 33 - 0
csrc/layer_norm/ln_utils.cuh

@@ -64,6 +64,39 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG)                \
+    void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,            \
+                                                                                const bool configure_params) {                                       \
+        launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(          \
+            launch_params, configure_params);                                                                                                        \
+    }                                                                                                                                                \
+    static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
+        ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define REGISTER_PARALLEL_BWD_LAUNCHER(                                                                                                              \
+    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                           \
+    void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,            \
+                                                                                const bool configure_params) {                                       \
+        launch_parallel_residual_<WTYPE,                                                                                                             \
+                ITYPE,                                                                                                                               \
+                RTYPE,                                                                                                                               \
+                OTYPE,                                                                                                                               \
+                CTYPE,                                                                                                                               \
+                uint32_t,                                                                                                                            \
+                HIDDEN_SIZE,                                                                                                                         \
+                CTAS_PER_ROW,                                                                                                                        \
+                WARPS_M,                                                                                                                             \
+                WARPS_N,                                                                                                                             \
+                BYTES_PER_LDG,                                                                                                                       \
+                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                                            \
+    }                                                                                                                                                \
+    static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
+        ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 inline __device__ float2 operator+(const float2 & a, const float2 & b){
     return {a.x + b.x, a.y + b.y};
 }

+ 32 - 0
csrc/layer_norm/setup.py

@@ -139,6 +139,38 @@ ext_modules.append(
             "ln_bwd_5120.cu",
             "ln_fwd_6144.cu",
             "ln_bwd_6144.cu",
+            "ln_fwd_7168.cu",
+            "ln_bwd_7168.cu",
+            "ln_fwd_8192.cu",
+            "ln_bwd_8192.cu",
+            "ln_parallel_fwd_256.cu",
+            "ln_parallel_bwd_256.cu",
+            "ln_parallel_fwd_512.cu",
+            "ln_parallel_bwd_512.cu",
+            "ln_parallel_fwd_768.cu",
+            "ln_parallel_bwd_768.cu",
+            "ln_parallel_fwd_1024.cu",
+            "ln_parallel_bwd_1024.cu",
+            "ln_parallel_fwd_1280.cu",
+            "ln_parallel_bwd_1280.cu",
+            "ln_parallel_fwd_1536.cu",
+            "ln_parallel_bwd_1536.cu",
+            "ln_parallel_fwd_2048.cu",
+            "ln_parallel_bwd_2048.cu",
+            "ln_parallel_fwd_2560.cu",
+            "ln_parallel_bwd_2560.cu",
+            "ln_parallel_fwd_3072.cu",
+            "ln_parallel_bwd_3072.cu",
+            "ln_parallel_fwd_4096.cu",
+            "ln_parallel_bwd_4096.cu",
+            "ln_parallel_fwd_5120.cu",
+            "ln_parallel_bwd_5120.cu",
+            "ln_parallel_fwd_6144.cu",
+            "ln_parallel_bwd_6144.cu",
+            "ln_parallel_fwd_7168.cu",
+            "ln_parallel_bwd_7168.cu",
+            "ln_parallel_fwd_8192.cu",
+            "ln_parallel_bwd_8192.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3"] + generator_flag,

+ 21 - 8
flash_attn/models/gpt.py

@@ -37,6 +37,11 @@ try:
 except ImportError:
     dropout_add_layer_norm = None
 
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+except ImportError:
+    dropout_add_layer_norm_parallel_residual = None
+
 try:
     from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
 except ImportError:
@@ -282,8 +287,10 @@ class GPTModel(GPTPreTrainedModel):
                                      for i in range(config.num_hidden_layers)])
 
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError('dropout_add_layer_norm is not installed')
+        if self.fused_dropout_add_ln:
+            if ((not self.parallel_block and dropout_add_layer_norm is None)
+                or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
+                raise ImportError('dropout_layer_norm is not installed')
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
             self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
@@ -340,13 +347,19 @@ class GPTModel(GPTPreTrainedModel):
                                 if residual is not None else dropped + dropped2)
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
-                assert not self.parallel_block
                 # Set prenorm=False here since we don't need the residual
-                hidden_states = dropout_add_layer_norm(
-                    hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
-                    self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
-                    residual_in_fp32=self.residual_in_fp32
-                )
+                if not self.parallel_block:
+                    hidden_states = dropout_add_layer_norm(
+                        hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
+                        self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
+                        residual_in_fp32=self.residual_in_fp32
+                    )
+                else:
+                    hidden_states, _ = dropout_add_layer_norm_parallel_residual(
+                        hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
+                        None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
+                        prenorm=False, residual_in_fp32=self.residual_in_fp32
+                    )
         return hidden_states
 
 

+ 30 - 15
flash_attn/modules/block.py

@@ -18,6 +18,11 @@ try:
 except ImportError:
     dropout_add_layer_norm = None
 
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+except ImportError:
+    dropout_add_layer_norm_parallel_residual = None
+
 
 class Block(nn.Module):
 
@@ -64,7 +69,7 @@ class Block(nn.Module):
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
 
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
@@ -214,7 +219,6 @@ class ParallelBlock(nn.Module):
         super().__init__()
         self.tied_norm = tied_norm
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
         self.residual_in_fp32 = residual_in_fp32
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)
@@ -229,7 +233,7 @@ class ParallelBlock(nn.Module):
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
 
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
@@ -262,19 +266,30 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
-        dropped1 = self.dropout1(hidden_states1)
-        # For the very 1st block, we only want 1 dropout, not two different dropouts
-        if hidden_states2 is not None:
-            dropped2 = self.dropout2(hidden_states2)
-            residual = ((residual + dropped1 + dropped2)
-                        if residual is not None else dropped1 + dropped2)
+        if not self.fused_dropout_add_ln:
+            dropped1 = self.dropout1(hidden_states1)
+            # For the very 1st block, we only want 1 dropout, not two different dropouts
+            if hidden_states2 is not None:
+                dropped2 = self.dropout2(hidden_states2)
+                residual = ((residual + dropped1 + dropped2)
+                            if residual is not None else dropped1 + dropped2)
+            else:
+                residual = (residual + dropped1) if residual is not None else dropped1
+            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                              if not self.tied_norm else hidden_states1)
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
         else:
-            residual = (residual + dropped1) if residual is not None else dropped1
-        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
-        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
-                          if not self.tied_norm else hidden_states1)
-        if self.residual_in_fp32:
-            residual = residual.to(torch.float32)
+            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
+                              if not self.tied_norm else (None, None))
+            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
+                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
+                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
+                prenorm=True, residual_in_fp32=self.residual_in_fp32
+            )
+            if self.tied_norm:
+                hidden_states2 = hidden_states1
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)

+ 109 - 2
flash_attn/ops/layer_norm.py

@@ -99,6 +99,46 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
         return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
+def _dropout_add_layer_norm_parallel_residual_forward(
+    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
+    epsilon, residual_in_fp32=False, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous
+    """
+    hidden_size = gamma0.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
+    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
+        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+        None, residual_in_fp32, is_rms_norm
+    )
+    # dmask0 and dmask1 are None if dropout_p == 0.0 (dmask1 is also None if x1 is None)
+    # xmat is None if dropout_p == 0.0, x1 is None, residual is None,
+    # and the residual dtype matches the input dtype
+    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
+
+
+def _dropout_add_layer_norm_parallel_residual_backward(
+    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+    dropout_p, has_x1, has_residual, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + drop(x1) + residual was not returned in the fwd).
+    """
+    hidden_size = gamma0.numel()
+    xmat = x.view((-1, hidden_size))
+    dz0mat = dz0.view(xmat.shape)
+    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
+        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+        dropout_p, has_x1, has_residual, is_rms_norm
+    )
+    # dresidualmat is None if not has_residual
+    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
+
+
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
@@ -115,7 +155,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
-        ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
+        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = residual is not None
@@ -168,7 +208,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
         x_shape = (-1, *x0.shape[1:])
-        ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
+        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                               x0_subset, out_subset)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
@@ -208,6 +248,60 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                 None, None, None, None, None, None, None, None)
 
 
+class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
+        gamma0 = gamma0.contiguous()
+        beta0 = beta0.contiguous() if beta0 is not None else None
+        gamma1 = gamma1.contiguous() if gamma1 is not None else None
+        beta1 = beta1.contiguous() if beta1 is not None else None
+        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
+            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+            residual_in_fp32, is_rms_norm
+        )
+        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
+        ctx.prenorm = prenorm
+        ctx.dropout_p = dropout_p
+        ctx.has_x1 = x1 is not None
+        ctx.has_residual = residual is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta0 is not None
+        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
+        if not return_dmask:
+            return z if not prenorm else (*z, xmat.view(x0.shape))
+        else:
+            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask0)
+            ctx.mark_non_differentiable(dmask1)
+            return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
+
+    @staticmethod
+    def backward(ctx, dz0, dz1, *args):
+        dz0 = dz0.contiguous()  # this happens!
+        dz1 = dz1.contiguous() if dz1 is not None else None
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
+        dropout_p = ctx.dropout_p
+        has_x1 = ctx.has_x1
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
+            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
+            has_residual, ctx.is_rms_norm
+        )
+        dx0 = dx0mat.view(x.shape)
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
+        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
+                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
+
+
 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
 
@@ -237,6 +331,19 @@ def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon
     )
 
 
+def dropout_add_layer_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
+    residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
+        False, return_dropout_mask
+    )
+
+
 class DropoutAddLayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):

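For clarity, here is a hedged pure-PyTorch sketch of what dropout_add_layer_norm_parallel_residual computes: drop(x0) [+ drop(x1)] [+ residual], followed by one or two LayerNorms over the same sum. It mirrors the reference math used in the tests further down; the function name is illustrative, dropout is applied as in training mode, and the random masks will of course differ from the fused kernel's.

import torch
import torch.nn.functional as F

def parallel_residual_ln_reference(x0, x1, residual, weight0, bias0, weight1, bias1,
                                   dropout_p, eps, prenorm=False, residual_in_fp32=False):
    """Unfused sketch of the parallel-residual op (illustration, not the kernel)."""
    if residual is not None:
        residual_dtype = residual.dtype
    else:
        residual_dtype = torch.float32 if residual_in_fp32 else x0.dtype
    acc = F.dropout(x0, dropout_p).float()
    if x1 is not None:
        acc = acc + F.dropout(x1, dropout_p).float()
    if residual is not None:
        acc = acc + residual.float()
    acc = acc.to(residual_dtype)
    hidden_size = x0.shape[-1]
    # Two norms share the same summed residual stream; the second is skipped
    # when the norms are tied (weight1 is None).
    out0 = F.layer_norm(acc.to(weight0.dtype), (hidden_size,), weight0, bias0, eps).to(x0.dtype)
    out1 = (F.layer_norm(acc.to(weight1.dtype), (hidden_size,), weight1, bias1, eps).to(x0.dtype)
            if weight1 is not None else None)
    return (out0, out1, acc) if prenorm else (out0, out1)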
+ 14 - 0
flash_attn/ops/rms_norm.py

@@ -5,6 +5,7 @@ import torch
 from torch.nn import init
 
 from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
+from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
 
 
 def rms_norm(x, weight, epsilon):
@@ -37,6 +38,19 @@ def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon,
     )
 
 
+def dropout_add_rms_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1,
+    dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
+        True, return_dropout_mask
+    )
+
+
 class DropoutAddRMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):

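A hedged usage sketch for the RMSNorm variant (assumes the dropout_layer_norm extension is built and a CUDA device is available; shapes, dtypes, and values are arbitrary):

import torch
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual

x0 = torch.randn(2, 16, 256, device='cuda', dtype=torch.float16)  # e.g. attention branch
x1 = torch.randn_like(x0)                                         # e.g. MLP branch (may be None)
w0 = torch.ones(256, device='cuda', dtype=torch.float16)
w1 = torch.ones(256, device='cuda', dtype=torch.float16)
# RMSNorm has no bias, so bias0/bias1 are None. With prenorm=True the updated
# residual stream is also returned; residual=None here as in a model's first block.
out0, out1, residual = dropout_add_rms_norm_parallel_residual(
    x0, x1, None, w0, None, w1, None, 0.1, 1e-5,
    prenorm=True, residual_in_fp32=True,
)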
+ 1 - 1
tests/models/test_gpt_neox.py

@@ -35,7 +35,7 @@ def test_gpt_neox_optimized(model_name):
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)

+ 1 - 1
tests/models/test_gptj.py

@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
     config.fused_bias_fc = True
     config.fused_mlp = True
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)

+ 344 - 48
tests/ops/test_dropout_layer_norm.py

@@ -10,11 +10,14 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
 from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
 from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
 from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
+from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
 
 try:
     from apex.normalization import FusedRMSNorm
+    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
 except:
-    FusedRMSNorm = None
+    FusedRMSNorm, fused_rms_norm_affine = None, None
 
 
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@@ -35,8 +38,8 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
-@pytest.mark.parametrize('hidden_size', [256])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                      dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
@@ -64,11 +67,11 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
     if has_rowscale:
         rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
         survival_rate = 0.87
@@ -95,14 +98,14 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
             model.bias.copy_(model_pt.bias)
             model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+    out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                      model.epsilon, rowscale=rowscale, layerscale=colscale,
                                      residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
@@ -116,8 +119,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     out_ref.backward(g)
     assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
-    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
+    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
     if not is_rms_norm:
         assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
     if has_colscale:
@@ -145,9 +148,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
                         requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
-    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-    x1 = x1_pt.detach().clone().requires_grad_()
-    x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+    res = res_pt.detach().clone().requires_grad_()
+    res_ref = res_pt.detach().clone().float().requires_grad_()
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
     torch.nn.init.normal_(model_pt.bias)
@@ -161,9 +164,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     model_pt.eval()
     model.eval()
     model_ref.eval()
-    out = model(x0, x1)
-    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
-    residual_ref = x0_ref + x1_ref
+    out = model(x0, res)
+    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
+    residual_ref = x0_ref + res_ref
     out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
     out_ref = model_ref(residual_ref)
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@@ -215,11 +218,11 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
     if has_rowscale:
         rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
         survival_rate = 0.87
@@ -247,15 +250,15 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
             model.bias.copy_(model_pt.bias)
             model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+    out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                                model.epsilon, rowscale=rowscale,
                                                layerscale=colscale, prenorm=True,
                                                residual_in_fp32=residual_in_fp32,
                                                return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
@@ -272,7 +275,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
     assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     if not is_rms_norm:
         assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
@@ -301,9 +304,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
                         requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
-    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
-    x1 = x1_pt.detach().clone().requires_grad_()
-    x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+    res = res_pt.detach().clone().requires_grad_()
+    res_ref = res_pt.detach().clone().float().requires_grad_()
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
     torch.nn.init.normal_(model_pt.bias)
@@ -318,9 +321,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
     model_pt.eval()
     model.eval()
     model_ref.eval()
-    out, residual = model(x0, x1)
-    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
-    residual_ref = x0_ref + x1_ref
+    out, residual = model(x0, res)
+    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
+    residual_ref = x0_ref + res_ref
     out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
     out_ref = model_ref(residual_ref)
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@@ -382,11 +385,11 @@ def test_dropout_layer_norm_subset_training(
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
 
     if has_colscale:
         x0_scaled_pt = x0_pt * colscale_pt
@@ -409,7 +412,7 @@ def test_dropout_layer_norm_subset_training(
 
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = dropout_add_layer_norm_subset(
-        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
         x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
         out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
         return_dropout_mask=True)
@@ -424,8 +427,8 @@ def test_dropout_layer_norm_subset_training(
     dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
     dmask_expanded[x0_mask_batch] = dmask
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
@@ -440,7 +443,7 @@ def test_dropout_layer_norm_subset_training(
     out_ref.backward(g)
     assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
@@ -502,11 +505,11 @@ def test_dropout_layer_norm_subset_prenorm_training(
     else:
         colscale = None
     if has_residual:
-        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
-        x1 = x1_pt.detach().clone().requires_grad_()
-        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
     else:
-        x1 = None
+        res = None
 
     if has_colscale:
         x0_scaled_pt = x0_pt * colscale_pt
@@ -529,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(
 
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = dropout_add_layer_norm_subset(
-        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
         x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
         out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
         return_dropout_mask=True)
@@ -544,8 +547,8 @@ def test_dropout_layer_norm_subset_prenorm_training(
     dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
     dmask_expanded[x0_mask_batch] = dmask
     if has_residual:
-        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
-        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
     else:
         residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
         residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
@@ -562,8 +565,301 @@ def test_dropout_layer_norm_subset_prenorm_training(
     (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
     assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
     if has_residual:
-        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
+
+
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+# @pytest.mark.parametrize('is_rms_norm', [False])
+@pytest.mark.parametrize('tied_norm', [False, True])
+# @pytest.mark.parametrize('tied_norm', [False])
+@pytest.mark.parametrize('has_residual', [True, False])
+# @pytest.mark.parametrize('has_residual', [False])
+@pytest.mark.parametrize('has_x1', [True, False])
+# @pytest.mark.parametrize('has_x1', [True])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+# @pytest.mark.parametrize('weight_dtype', [torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
+def test_dropout_layer_norm_parallel_residual_training(
+    hidden_size, input_dtype, residual_dtype, weight_dtype,
+    dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
+):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    if is_rms_norm and fused_rms_norm_affine is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
+                           else dropout_add_rms_norm_parallel_residual)
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_x1:
+        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                            requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_residual:
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
+    else:
+        res = None
+    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+             if not is_rms_norm else None)
+    weight0_pt = weight0.detach().clone().requires_grad_()
+    weight0_ref = weight0.detach().clone().float().requires_grad_()
+    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
+    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
+    if not tied_norm:
+        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+                 if not is_rms_norm else None)
+        weight1_pt = weight1.detach().clone().requires_grad_()
+        weight1_ref = weight1.detach().clone().float().requires_grad_()
+        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
+        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
+    else:
+        weight1, bias1 = None, None
+    epsilon = 1e-5
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+
+    out0, out1, dmask0, dmask1 = our_layer_norm_func(
+        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
+        epsilon, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
+    )
+    assert out0.dtype == input_dtype
+    if not tied_norm:
+        assert out1.dtype == input_dtype
+    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
+    if has_residual:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                          + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+                          + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                          + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
+    else:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                          + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p))
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
+    if not is_rms_norm:
+        out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt,
+                               eps=epsilon).to(dtype=input_dtype)
+        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
+        if not tied_norm:
+            out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
+                                   bias1_pt, eps=epsilon).to(dtype=input_dtype)
+            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
+    else:
+        out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,),
+                                        eps=epsilon).to(dtype=input_dtype)
+        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
+        if not tied_norm:
+            out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
+                                            (hidden_size,), eps=epsilon).to(dtype=input_dtype)
+            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
+
+    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
+    if not tied_norm:
+        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
+
+    g0 = torch.randn_like(out0) / batch_size
+    if tied_norm:
+        out0.backward(g0)
+        out0_pt.backward(g0)
+        out0_ref.backward(g0)
+    else:
+        g1 = torch.randn_like(out1) / batch_size
+        (out0 * g0 + out1 * g1).sum().backward()
+        (out0_pt * g0 + out1_pt * g1).sum().backward()
+        (out0_ref * g0 + out1_ref * g1).sum().backward()
+    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
+    if has_x1:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    if has_residual:
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
+    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
+    if not tied_norm:
+        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
+        if not is_rms_norm:
+            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
+
+
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+# @pytest.mark.parametrize('is_rms_norm', [False])
+@pytest.mark.parametrize('tied_norm', [False, True])
+# @pytest.mark.parametrize('tied_norm', [False])
+@pytest.mark.parametrize('has_residual', [True, False])
+# @pytest.mark.parametrize('has_residual', [False])
+@pytest.mark.parametrize('has_x1', [True, False])
+# @pytest.mark.parametrize('has_x1', [True])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+# @pytest.mark.parametrize('weight_dtype', [torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
+def test_dropout_layer_norm_parallel_residual_prenorm_training(
+    hidden_size, input_dtype, residual_dtype, weight_dtype,
+    dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
+):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    if is_rms_norm and fused_rms_norm_affine is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
+                           else dropout_add_rms_norm_parallel_residual)
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_x1:
+        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                            requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_residual:
+        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res = res_pt.detach().clone().requires_grad_()
+        res_ref = res_pt.detach().clone().float().requires_grad_()
+    else:
+        res = None
+    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+             if not is_rms_norm else None)
+    weight0_pt = weight0.detach().clone().requires_grad_()
+    weight0_ref = weight0.detach().clone().float().requires_grad_()
+    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
+    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
+    if not tied_norm:
+        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+                 if not is_rms_norm else None)
+        weight1_pt = weight1.detach().clone().requires_grad_()
+        weight1_ref = weight1.detach().clone().float().requires_grad_()
+        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
+        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
+    else:
+        weight1, bias1 = None, None
+    epsilon = 1e-5
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+
+    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
+        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
+        epsilon, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
+    )
+    assert out0.dtype == input_dtype
+    if not tied_norm:
+        assert out1.dtype == input_dtype
+    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
+    if has_residual:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                          + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+                          + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                          + res_pt.float()).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
+    else:
+        if has_x1:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+                          + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+                            + (x1_ref * dmask1.float()) / (1 - dropout_p))
+        else:
+            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
+    if not is_rms_norm:
+        out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt,
+                               eps=epsilon).to(dtype=input_dtype)
+        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
+        if not tied_norm:
+            out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
+                                   bias1_pt, eps=epsilon).to(dtype=input_dtype)
+            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
+    else:
+        out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,),
+                                        eps=epsilon).to(dtype=input_dtype)
+        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
+        if not tied_norm:
+            out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
+                                            (hidden_size,), eps=epsilon).to(dtype=input_dtype)
+            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
+
+    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
+    if not tied_norm:
+        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
+    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
+
+    g0 = torch.randn_like(out0) / batch_size
+    if tied_norm:
+        (out0 * F.sigmoid(residual)).backward(g0)
+        (out0_pt * F.sigmoid(residual_pt)).backward(g0)
+        (out0_ref * F.sigmoid(residual_ref)).backward(g0)
+    else:
+        g1 = torch.randn_like(out1) / batch_size
+        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
+        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
+        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
+    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
+    if has_x1:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    if has_residual:
+        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
+    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
+    if not tied_norm:
+        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
+        if not is_rms_norm:
+            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
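
To exercise only the new parallel-residual tests, a convenience sketch (assumes it is run from the repo root with the extension installed; the keyword matches the test names added above):

import sys
import pytest

# Select the parallel-residual tests by keyword and exit with pytest's status code.
sys.exit(pytest.main(["-k", "parallel_residual", "tests/ops/test_dropout_layer_norm.py"]))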