Browse Source

[LayerNorm] Support all dimensions up to 6k (if divisible by 8)

Tri Dao 2 years ago
parent
commit
8c6609ae1a

+ 4 - 3
csrc/layer_norm/README.md

@@ -1,10 +1,11 @@
-This CUDA extension implements fused dropout + residual + LayerNorm, based on
+This CUDA extension implements fused dropout + residual + LayerNorm, building on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
+We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
 
-This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
+If you want to use it for dimensions larger than 6k, please file an issue.
 
-It has only been tested on A100s.
+This extension has only been tested on A100s.
 
 ```sh
 cd csrc/layer_norm && pip install .

+ 2 - 0
csrc/layer_norm/ln.h

@@ -64,6 +64,8 @@ struct ParamsBase {
     void *gamma;
     void *rowscale;
 
+    float inverse_cols;
+
     float dropout_keep_p;
     float dropout_scale;
 

+ 17 - 4
csrc/layer_norm/ln_api.cpp

@@ -129,6 +129,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 
     TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
 
     TORCH_CHECK(epsilon >= 0.f);
 
@@ -156,8 +157,10 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
 
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
     // Request the kernel launcher.
-    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
+    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
     // Query the kernel-specific launch parameters.
     launcher(launch_params, true);
@@ -178,6 +181,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     params.z = z.data_ptr();
     params.epsilon = epsilon;
     params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
 
     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
@@ -263,6 +267,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     }
 
     auto hidden_size = gamma.numel();
+    TORCH_CHECK(hidden_size == cols);
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
 
     TORCH_CHECK(mu.numel() == rows);
     TORCH_CHECK(mu.sizes() == rsigma.sizes());
@@ -285,7 +291,9 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
 
-    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
     launcher(launch_params, true, /*prenorm=*/false);
 
@@ -308,6 +316,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     params.dbeta_part = dbeta_part.data_ptr();
     params.dgamma_part = dgamma_part.data_ptr();
     params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
 
     if( launch_params.barrier_size > 0 ) {
         // TODO Any way to avoid this?
@@ -385,6 +394,8 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
     }
 
     auto hidden_size = gamma.numel();
+    TORCH_CHECK(hidden_size == cols);
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
 
     TORCH_CHECK(mu.numel() == rows);
     TORCH_CHECK(mu.sizes() == rsigma.sizes());
@@ -407,8 +418,9 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
     launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
 
-    // TODO: how to set template param for launcher
-    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
     launcher(launch_params, true, /*prenorm=*/true);
 
@@ -432,6 +444,7 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     //
     params.dbeta_part = dbeta_part.data_ptr();
     params.dgamma_part = dgamma_part.data_ptr();
     params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
 
     if( launch_params.barrier_size > 0 ) {
         // TODO Any way to avoid this?

+ 15 - 0
csrc/layer_norm/ln_bwd_1024.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_1280.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_1536.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_2048.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_256.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_2560.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_3072.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_4096.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_512.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_5120.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_6144.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_bwd_768.cu

@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

+ 198 - 88
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -1,8 +1,13 @@
 #pragma once
 
+#include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "static_switch.h"
+
 namespace layer_norm {
 
-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_rowscale>
+template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
 void ln_bwd_kernel(layer_norm::BwdParams params) {
 
@@ -59,13 +64,17 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 
     Sum<reduce_t> sum;
 
-    constexpr float rn = 1.f / float(COLS);
+    const index_t num_valid_ldgs =
+        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
+
     Wvec gamma[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
-        gamma[it].load_from(params.gamma, idx);
-        idx += Ktraits::VEC_COLS_PER_LDG;
+        if (Is_even_cols || (it < num_valid_ldgs)) {
+            gamma[it].load_from(params.gamma, idx);
+            idx += Ktraits::VEC_COLS_PER_LDG;
+        }
     }
     // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
     // last blocks with syncthreads!
@@ -74,79 +83,85 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
         const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
         const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
-        const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast<const input_t *>(params.rowscale)[row]) : 1.0f;
+        const compute_t rowscale_val =
+            params.rowscale == nullptr ? 1.0f : compute_t(static_cast<const input_t *>(params.rowscale)[row]);
         Mvec dmask[LDGS];
         Rvec dx[LDGS];
         compute_t dy[LDGS * NUM_ELTS];
         compute_t y[LDGS * NUM_ELTS];
         compute_t mdy_local = 0.f;
         compute_t mdyy_local = 0.f;
-        index_t idx = row * Ktraits::VEC_COLS + c;
+        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            Rvec x;
-            Ovec dz;
-            dz.load_from(params.dz, idx);
-            if (Prenorm) { dx[it].load_from(params.dx, idx); }
-            x.load_from(params.x, idx);
-            if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
-            idx += Ktraits::VEC_COLS_PER_LDG;
-            #pragma unroll
-            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                compute_t x_tmp = x.data.elt[jt];
-                compute_t y_tmp = rs_r * (x_tmp - mu_r);
-                compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
-                dy_tmp *= compute_t(dz.data.elt[jt]);
-                compute_t dz_tmp = dz.data.elt[jt];
-
-                mdy_local += dy_tmp;
-                mdyy_local += dy_tmp * y_tmp;
-
-                dy[it * NUM_ELTS + jt] = dy_tmp;
-                y[it * NUM_ELTS + jt] = y_tmp;
-
-                dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
-                dz_sum[it].data.elt[jt] += dz_tmp;
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Rvec x;
+                Ovec dz;
+                dz.load_from(params.dz, idx);
+                if (Prenorm) { dx[it].load_from(params.dx, idx); }
+                x.load_from(params.x, idx);
+                if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
+                idx += Ktraits::VEC_COLS_PER_LDG;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t x_tmp = x.data.elt[jt];
+                    compute_t y_tmp = rs_r * (x_tmp - mu_r);
+                    compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
+                    compute_t dz_tmp = dz.data.elt[jt];
+
+                    mdy_local += dy_tmp;
+                    mdyy_local += dy_tmp * y_tmp;
+
+                    dy[it * NUM_ELTS + jt] = dy_tmp;
+                    y[it * NUM_ELTS + jt] = y_tmp;
+
+                    dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
+                    dz_sum[it].data.elt[jt] += dz_tmp;
+                }
             }
         }
 
         reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
-        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
-        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
+        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
+        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
 
-        idx = row * Ktraits::VEC_COLS + c;
+        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            Ivec dx0;
-            Rvec dx1;
-            #pragma unroll
-            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                compute_t dy_tmp = dy[it * NUM_ELTS + jt];
-                compute_t y_tmp = y[it * NUM_ELTS + jt];
-                compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
-                compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
-                if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
-                compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res;
-                if (Is_dropout) {
-                    dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
-                } else {
-                    dx0.data.elt[jt] = dx0_tmp_res;
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ivec dx0;
+                Rvec dx1;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t dy_tmp = dy[it * NUM_ELTS + jt];
+                    compute_t y_tmp = y[it * NUM_ELTS + jt];
+                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
+                    compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
+                    if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                    compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
+                    if (Is_dropout) {
+                        dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
+                    } else {
+                        dx0.data.elt[jt] = dx0_tmp_res;
+                    }
                 }
+                if (Has_residual) { dx1.store_to(params.dx1, idx); }
+                dx0.store_to(params.dx0, idx);
+                idx += Ktraits::VEC_COLS_PER_LDG;
             }
-            if (Has_residual) { dx1.store_to(params.dx1, idx); }
-            dx0.store_to(params.dx0, idx);
-            idx += Ktraits::VEC_COLS_PER_LDG;
         }
 
     }  // end: grid stride loop
 
     if( WARPS_M == 1 ) {
-        idx = r * Ktraits::VEC_COLS + c;
+        idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            dz_sum[it].store_to(params.dbeta_part, idx);
-            dzy_sum[it].store_to(params.dgamma_part, idx);
-            idx += Ktraits::VEC_COLS_PER_LDG;
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                dz_sum[it].store_to(params.dbeta_part, idx);
+                dzy_sum[it].store_to(params.dgamma_part, idx);
+                idx += Ktraits::VEC_COLS_PER_LDG;
+            }
         }
     } else {
         static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
@@ -188,21 +203,23 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             }
         }
 
-        compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
+        const index_t num_valid_writes
+            = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
+        compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
+        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
         for( int jt = 0; jt < NUM_RES; jt++ ) {
-            *dgamma_part = cta_dzy_sum[jt];
-            dgamma_part += Ktraits::THREADS_PER_CTA;
+            if (Is_even_cols || (jt < num_valid_writes)) {
+                *dgamma_part = cta_dzy_sum[jt];
+                dgamma_part += Ktraits::THREADS_PER_CTA;
+                *dbeta_part = cta_dz_sum[jt];
+                dbeta_part += Ktraits::THREADS_PER_CTA;
+            }
         }
 
-        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
-        for( int jt = 0; jt < NUM_RES; jt++ ) {
-            *dbeta_part = cta_dz_sum[jt];
-            dbeta_part += Ktraits::THREADS_PER_CTA;
-        }
     }
 }
 
-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_even_cols>
 __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
 void ln_bwd_finalize_kernel(BwdParams params)
 {
@@ -236,19 +253,21 @@ void ln_bwd_finalize_kernel(BwdParams params)
         Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
         memset(&dgamma_local, 0, sizeof(dgamma_local));
         memset(&dbeta_local, 0, sizeof(dbeta_local));
-        for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
-            index_t idx = row * Kernel_traits::COLS + col;
-
-            Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
-            dbeta_part.load_from(params.dbeta_part, idx);
-            dgamma_part.load_from(params.dgamma_part, idx);
-            #pragma unroll
-            for( int it = 0; it < NUM_ELT; it++ ) {
-                dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
-                dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+        if (Is_even_cols || col < params.cols) {
+            for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
+                // index_t idx = row * Kernel_traits::COLS + col;
+                index_t idx = row * params.cols + col;
+
+                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
+                dbeta_part.load_from(params.dbeta_part, idx);
+                dgamma_part.load_from(params.dgamma_part, idx);
+                #pragma unroll
+                for( int it = 0; it < NUM_ELT; it++ ) {
+                    dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
+                    dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+                }
             }
         }
-
         void * smem_gamma = smem_;
         void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
 
@@ -305,24 +324,115 @@ void ln_bwd_finalize_kernel(BwdParams params)
         __syncthreads();
 
         // Pack and store: 2-wide stores with half the threads.
-        if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
-
-            using src_t = typename TypeToVec2<compute_t>::Type;
-            using dst_t = typename TypeToVec2<weight_t>::Type;
-            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
-            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+        if (Is_even_cols || col_out * 2 < params.cols) {
+            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
+
+                using src_t = typename TypeToVec2<compute_t>::Type;
+                using dst_t = typename TypeToVec2<weight_t>::Type;
+                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
+                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+
+                dgamma_vec2.load_from(smem_gamma_out, lane);
+                dbeta_vec2.load_from(smem_beta_out, lane);
+                #pragma unroll
+                for( int it = 0; it < NUM_ELT; it++ ) {
+                    dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
+                    dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
+                }
+                dgamma_out2.store_to(params.dgamma, col_out);
+                dbeta_out2.store_to(params.dbeta, col_out);
 
-            dgamma_vec2.load_from(smem_gamma_out, lane);
-            dbeta_vec2.load_from(smem_beta_out, lane);
-            #pragma unroll
-            for( int it = 0; it < NUM_ELT; it++ ) {
-                dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
-                dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
             }
-            dgamma_out2.store_to(params.dgamma, col_out);
-            dbeta_out2.store_to(params.dbeta, col_out);
-
         }
     }
 }
 }  // namespace layer_norm
+
+using namespace layer_norm;
+
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE,
+    int CTAS_PER_ROW,
+    int WARPS_M,
+    int WARPS_N,
+    int BYTES_PER_LDG_MAIN,
+    int BYTES_PER_LDG_FINAL
+>
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
+
+    using Kernel_traits = Kernel_traits<weight_t,
+                                        input_t,
+                                        residual_t,
+                                        output_t,
+                                        compute_t,
+                                        index_t,
+                                        HIDDEN_SIZE,
+                                        CTAS_PER_ROW,
+                                        WARPS_M,
+                                        WARPS_N,
+                                        BYTES_PER_LDG_MAIN
+                                        >;
+    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
+    bool has_residual = launch_params.params.dx1 != nullptr;
+    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+    BOOL_SWITCH(prenorm, PrenormConst, [&] {
+        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+                    if( configure_params ) {
+                        int ctas_per_sm;
+                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
+                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                        launch_params.barrier_size = 0;
+                        launch_params.workspace_bytes = 0;
+                        if(Kernel_traits::CTAS_PER_ROW > 1) {
+                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                          * Kernel_traits::WARPS_M
+                                                          * Kernel_traits::CTAS_PER_ROW
+                                                          * sizeof(typename Kernel_traits::reduce_t)
+                                                          * 2;
+                        }
+                        return;
+                    }
+
+                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
+                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+                    }
+                    auto stream = launch_params.stream;
+                    auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
+                    } else {
+                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                        dim3 block(Kernel_traits::THREADS_PER_CTA);
+                        void *params_ = (void *)&launch_params.params;
+                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
+                    }
+
+                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
+                                                                              weight_t,
+                                                                              input_t,
+                                                                              residual_t,
+                                                                              output_t,
+                                                                              compute_t,
+                                                                              index_t,
+                                                                              32 * 32,  // THREADS_PER_CTA
+                                                                              BYTES_PER_LDG_FINAL>;
+
+                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
+                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                });
+            });
+        });
+    });
+}

+ 0 - 325
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu

@@ -1,325 +0,0 @@
-#include "ln.h"
-#include "ln_utils.cuh"
-#include "ln_kernel_traits.h"
-#include "ln_bwd_kernels.cuh"
-#include "static_switch.h"
-
-using namespace layer_norm;
-
-template<
-    typename weight_t,
-    typename input_t,
-    typename residual_t,
-    typename output_t,
-    typename compute_t,
-    typename index_t,
-    int HIDDEN_SIZE, 
-    int CTAS_PER_ROW, 
-    int WARPS_M, 
-    int WARPS_N, 
-    int BYTES_PER_LDG_MAIN,
-    int BYTES_PER_LDG_FINAL
->
-void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
-
-    using Kernel_traits = Kernel_traits<weight_t,
-                                        input_t,
-                                        residual_t,
-                                        output_t,
-                                        compute_t,
-                                        index_t,
-                                        HIDDEN_SIZE,
-                                        CTAS_PER_ROW,
-                                        WARPS_M,
-                                        WARPS_N,
-                                        BYTES_PER_LDG_MAIN
-                                        >;
-    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
-    bool has_residual = launch_params.params.dx1 != nullptr;
-    bool has_rowscale = launch_params.params.rowscale != nullptr;
-    BOOL_SWITCH(prenorm, PrenormConst, [&] {
-        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-                BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
-                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
-                    if( configure_params ) {
-                        int ctas_per_sm;
-                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
-                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
-                        launch_params.barrier_size = 0;
-                        launch_params.workspace_bytes = 0;
-                        if(Kernel_traits::CTAS_PER_ROW > 1) {
-                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
-                                                          * Kernel_traits::WARPS_M
-                                                          * Kernel_traits::CTAS_PER_ROW
-                                                          * sizeof(typename Kernel_traits::reduce_t)
-                                                          * 2;
-                        }
-                        return;
-                    }
-
-                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
-                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
-                    }
-                    auto stream = launch_params.stream;
-                    auto ctas_per_col = launch_params.params.ctas_per_col;
-
-                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
-                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
-                    } else {
-                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
-                        dim3 block(Kernel_traits::THREADS_PER_CTA);
-                        void *params_ = (void *)&launch_params.params;
-                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
-                    }
-
-                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
-                                                                              weight_t,
-                                                                              input_t,
-                                                                              residual_t,
-                                                                              output_t,
-                                                                              compute_t,
-                                                                              index_t,
-                                                                              32 * 32,  // THREADS_PER_CTA
-                                                                              BYTES_PER_LDG_FINAL>;
-
-                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
-                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
-                });
-            });
-        });
-    });
-}
-
-// Create backward launch function and register. Macro signature:
-//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
-
-REGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-
-REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-
-REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
-REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
-
-REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
-
-REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
-REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
-
-REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-
-REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
-REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
-
-REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-
-REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-
-REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
-REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
-
-// TD [2022-04-22] Disable most of these to speed up compile time
-
-// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
-
-// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4,  4, 4);
-// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4,  4, 4);
-// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
-
-// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
-// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
-// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
-// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
-// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
-
-// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
-
-// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4,  4, 4);
-// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4,  4, 4);
-// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4,  8, 4);
-
-// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4,  8, 4);
-// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
-// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
-// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
-// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
-// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
-// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
-
-// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
-
-// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
-// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);

+ 15 - 0
csrc/layer_norm/ln_fwd_1024.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_1280.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_1536.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_2048.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_256.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_2560.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_3072.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_4096.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_512.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_5120.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_6144.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
+REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);

+ 15 - 0
csrc/layer_norm/ln_fwd_768.cu

@@ -0,0 +1,15 @@
+#include "ln_fwd_kernels.cuh"
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

+ 0 - 302
csrc/layer_norm/ln_fwd_cuda_kernel.cu

@@ -1,302 +0,0 @@
-#include "ln.h"
-#include "ln_utils.cuh"
-#include "ln_kernel_traits.h"
-#include "ln_fwd_kernels.cuh"
-#include "static_switch.h"
-
-using namespace layer_norm;
-
-template<
-    typename weight_t,
-    typename input_t,
-    typename residual_t,
-    typename output_t,
-    typename compute_t,
-    typename index_t,
-    int HIDDEN_SIZE, 
-    int CTAS_PER_ROW, 
-    int WARPS_M, 
-    int WARPS_N, 
-    int BYTES_PER_LDG
->
-void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
-
-    using Kernel_traits = Kernel_traits<weight_t,
-                                        input_t,
-                                        residual_t,
-                                        output_t,
-                                        compute_t,
-                                        index_t,
-                                        HIDDEN_SIZE,
-                                        CTAS_PER_ROW,
-                                        WARPS_M,
-                                        WARPS_N,
-                                        BYTES_PER_LDG
-                                        >;
-    bool has_residual = launch_params.params.x1 != nullptr;
-    bool has_rowscale = launch_params.params.rowscale != nullptr;
-    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
-        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-            BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
-                auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
-                if( configure_params ) {
-                    int ctas_per_sm;
-                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
-                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
-                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
-                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
-                    launch_params.barrier_size = 0;
-                    launch_params.workspace_bytes = 0;
-                    if(Kernel_traits::CTAS_PER_ROW > 1) {
-                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
-                                                      * Kernel_traits::WARPS_M
-                                                      * Kernel_traits::CTAS_PER_ROW
-                                                      * sizeof(typename Kernel_traits::Stats::stats_t)
-                                                      * 2;
-                    }
-                    return;
-                }
-
-                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
-                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
-                }
-                auto stream = launch_params.stream;
-                auto ctas_per_col = launch_params.params.ctas_per_col;
-
-                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
-                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
-                } else {
-                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
-                    dim3 block(Kernel_traits::THREADS_PER_CTA);
-                    void *params_ = (void *)&launch_params.params;
-                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
-                }
-            });
-        });
-    });
-}
-
-// Create forward launch function and register. Macro signature:
-//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
-
-REGISTER_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-
-REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-
-REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-
-REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-
-REGISTER_FWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1,  4);
-REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1,  4);
-
-REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-
-REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-
-REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-
-REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-
-REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-
-// TD [2022-04-22] Disable most of these to speed up compile time
-
-// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
-// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
-// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
-// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
-
-// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4,  4);
-
-// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4,  4);
-
-// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4,  8);
-// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4,  8);
-// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4,  8);
-// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4,  8);
-// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4,  8);
-
-// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4,  8);
-// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4,  8);
-// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4,  4);
-
-// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4,  4);
-// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4,  4);
-
-// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
-
-// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
-// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);

+ 140 - 48
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -10,10 +10,13 @@
 #include <curand_kernel.h>
 
 #include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "static_switch.h"
 
 namespace layer_norm {
 
-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_rowscale>
+template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
 void ln_fwd_kernel(FwdParams params) {
 
@@ -73,57 +76,70 @@ void ln_fwd_kernel(FwdParams params) {
         curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
     }
 
+    const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
+
     Wvec gamma[LDGS];
     Wvec beta[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
-        gamma[it].load_from(params.gamma, idx);
-        beta[it].load_from(params.beta, idx);
-        idx += VEC_COLS_PER_LDG;
+        if (Is_even_cols || (it < num_valid_ldgs)) {
+            gamma[it].load_from(params.gamma, idx);
+            beta[it].load_from(params.beta, idx);
+            idx += VEC_COLS_PER_LDG;
+        }
     }
 
-    constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
-
     for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
-        const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f;
-        index_t idx = row * Ktraits::VEC_COLS + c;
+        const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]);
+        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
         compute_t xf[LDGS * NUM_ELTS];
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            Ivec x0;
-            Rvec x1;
-            Rvec x;
-            Mvec dmask;
-            x0.load_from(params.x0, idx);
-            if (Has_residual) { x1.load_from(params.x1, idx); }
-            #pragma unroll
-            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
-                // the more efficient curand_uniform4.
-                mask_t keep = true;
-                if (Is_dropout) {
-                    float rand = curand_uniform(&state);
-                    keep = mask_t(rand <= params.dropout_keep_p);
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ivec x0;
+                Rvec x1;
+                Rvec x;
+                Mvec dmask;
+                x0.load_from(params.x0, idx);
+                if (Has_residual) { x1.load_from(params.x1, idx); }
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
+                    // the more efficient curand_uniform4.
+                    mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
+                    compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
+                    compute_t x_ij;
+                    if (Has_residual) {
+                        compute_t x1_ij = compute_t(x1.data.elt[jt]);
+                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
+                    } else  {
+                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
+                    }
+                    if (save_x) { x.data.elt[jt] = x_ij; }
+                    xf[it * NUM_ELTS + jt] = x_ij;
+                    if (Is_dropout) { dmask.data.elt[jt] = keep; }
                 }
-                compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]);
-                compute_t x_ij;
-                if (Has_residual) {
-                    compute_t x1_ij = compute_t(x1.data.elt[jt]);
-                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
-                } else  {
-                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
-                }
-                if (save_x) { x.data.elt[jt] = x_ij; }
-                xf[it * NUM_ELTS + jt] = x_ij;
-                if (Is_dropout) { dmask.data.elt[jt] = keep; }
+                if (save_x) { x.store_to(params.x, idx); }
+                if (Is_dropout) { dmask.store_to(params.dmask, idx); }
+                idx += VEC_COLS_PER_LDG;
             }
-            if (save_x) { x.store_to(params.x, idx); }
-            if (Is_dropout) { dmask.store_to(params.dmask, idx); }
-            idx += VEC_COLS_PER_LDG;
         }
 
-        stats_t s = stats.compute(xf, rn);
+        static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
+        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
+        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
+        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
+        // Need to convert to int, otherwise the subtraction will wrap around.
+        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+            const index_t valid_partial_vecs_in_warp =
+                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
+                        int(THREADS_PER_WARP));
+            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
+        };
+        stats_t s = stats.template compute<Is_even_cols>(
+            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
+        );
 
         compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
         compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
@@ -132,28 +148,104 @@ void ln_fwd_kernel(FwdParams params) {
             mu_ptr[row] = mu;
         }
 
-        compute_t rs = rsqrtf(rn * m2 + params.epsilon);
+        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);
 
         if( bidn == 0 && warp_n == 0 && lane == 0 ) {
             rs_ptr[row] = rs;
         }
 
-        idx = row * Ktraits::VEC_COLS + c;
+        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            Ovec z;
-            #pragma unroll
-            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
-                output_t g_ij = gamma[it].data.elt[jt];
-                output_t b_ij = beta[it].data.elt[jt];
-                z.data.elt[jt] = (g_ij * y_ij + b_ij);
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ovec z;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+                    compute_t g_ij = gamma[it].data.elt[jt];
+                    compute_t b_ij = beta[it].data.elt[jt];
+                    z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
+                }
+                z.store_to(params.z, idx);
+                idx += VEC_COLS_PER_LDG;
             }
-            z.store_to(params.z, idx);
-            idx += VEC_COLS_PER_LDG;
         }
 
     }
 }
 
 }  // namespace layer_norm
+
+using namespace layer_norm;
+
+// Instantiate the forward-kernel traits for one (dtype, HIDDEN_SIZE, tiling)
+// combination and either (a) fill launch_params with occupancy / barrier /
+// workspace requirements when configure_params is true, or (b) launch the
+// fused dropout + residual + LayerNorm forward kernel.
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE,
+    int CTAS_PER_ROW,
+    int WARPS_M,
+    int WARPS_N,
+    int BYTES_PER_LDG
+>
+void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
+
+    using Kernel_traits = Kernel_traits<weight_t,
+                                        input_t,
+                                        residual_t,
+                                        output_t,
+                                        compute_t,
+                                        index_t,
+                                        HIDDEN_SIZE,
+                                        CTAS_PER_ROW,
+                                        WARPS_M,
+                                        WARPS_N,
+                                        BYTES_PER_LDG
+                                        >;
+    bool has_residual = launch_params.params.x1 != nullptr;
+    // The kernel is compiled for HIDDEN_SIZE; the uneven-cols path handles any
+    // actual column count below it (rounded up to HIDDEN_SIZE by the caller).
+    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
+        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+                if( configure_params ) {
+                    int ctas_per_sm;
+                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
+                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
+                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
+                    launch_params.barrier_size = 0;
+                    launch_params.workspace_bytes = 0;
+                    if(Kernel_traits::CTAS_PER_ROW > 1) {
+                        // Multi-CTA rows need an inter-CTA barrier plus a
+                        // per-(col, warp_m, cta) stats workspace, double-buffered.
+                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                      * Kernel_traits::WARPS_M
+                                                      * Kernel_traits::CTAS_PER_ROW
+                                                      * sizeof(typename Kernel_traits::Stats::stats_t)
+                                                      * 2;
+                    }
+                    return;
+                }
+
+                // Opt in to >48KB of dynamic shared memory when the traits need it.
+                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
+                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+                }
+                auto stream = launch_params.stream;
+                auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
+                } else {
+                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                    dim3 block(Kernel_traits::THREADS_PER_CTA);
+                    void *params_ = (void *)&launch_params.params;
+                    // Cooperative launch is required for the inter-CTA barrier.
+                    // Check its status like the other CUDA calls above instead
+                    // of silently dropping the cudaError_t return value.
+                    CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream));
+                }
+            });
+        });
+    });
+}

+ 27 - 18
csrc/layer_norm/ln_utils.cuh

@@ -530,20 +530,20 @@ struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
  
-template<typename T>
-inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
+template<typename T, typename int_t>
+inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
     //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
-    int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
+    const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
     
     #pragma unroll
     for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
         // Exchange
-        T n_b = warp_shuffle_down(n_a, step);
+        int_t n_b = warp_shuffle_down(n_a, step);
         T m_b = warp_shuffle_down(m_a, step);
         T m2_b = warp_shuffle_down(m2_a, step);
 
         // Update
-        const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
+        const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.
         const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
         const T delta = m_a - m_b;
         const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
@@ -647,23 +647,26 @@ struct Stats<T, 1, WARPS_M, WARPS_N> {
         smem1_ = smem0_ + WARPS_M * WARPS_N;
     }
 
-    template<uint32_t N>
-    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+    template<bool Is_even_cols, uint32_t N, typename function_t>
+    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
+                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
         stats_t * smem = use0_ ? smem0_ : smem1_;
         use0_ = !use0_;
         // Compute warp local for all WARPS_N
-        constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
-        stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
+        const auto warp_n = warp_stats_.reducer_.warp_n_;
+        const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
+        stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(
+            elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts
+        );
 
         //Each warp warp leader stores its stats
-        const auto warp_n = warp_stats_.reducer_.warp_n_;
         const auto lane = warp_stats_.reducer_.lane_;
         if( lane == 0 ) {
             smem[warp_n] = warp_stats;
         }
         __syncthreads();
 
-        T n = Zeros<T>::get();
+        int n = 0;
         T m = Zeros<T>::get();
         T m2 = Zeros<T>::get();
 
@@ -671,7 +674,7 @@ struct Stats<T, 1, WARPS_M, WARPS_N> {
         static_assert(WARPS_N <= 32);
         if(lane < WARPS_N){
             stats_t result = smem[lane];
-            n = N * THREADS_PER_WARP;
+            n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
             m = layer_norm::Get<0>::of<stats_t, T>(result);
             m2 = layer_norm::Get<1>::of<stats_t, T>(result);
         }
@@ -703,23 +706,29 @@ struct Stats<T, 1, WARPS_M, 1> {
     {
     }
 
-    template<uint32_t N>
-    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+    template<bool Is_even_cols, uint32_t N, typename function_t>
+    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
+                                      // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) {
+                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
 
         auto sum = Sum<T>();
 
         T m = Zeros<T>::get();
         #pragma unroll
         for( int it = 0; it < N; it++ ) {
-            m += elts[it];
+            if (Is_even_cols || (it < num_valid_elts)) {
+                m += elts[it];
+            }
         }
-        m = reducer_.allreduce(m, sum) * rn;
+        m = reducer_.allreduce(m, sum) * row_norm_factor;
 
         T m2 = Zeros<T>::get();
         #pragma unroll
         for( int it = 0; it < N; it++ ) {
-            T diff = (elts[it] - m);
-            m2 += diff * diff;
+            if (Is_even_cols || (it < num_valid_elts)) {
+                T diff = (elts[it] - m);
+                m2 += diff * diff;
+            }
         }
         m2 = reducer_.allreduce(m2, sum);
 

+ 24 - 2
csrc/layer_norm/setup.py

@@ -108,8 +108,30 @@ ext_modules.append(
         name="dropout_layer_norm",
         sources=[
             "ln_api.cpp",
-            "ln_fwd_cuda_kernel.cu",
-            "ln_bwd_semi_cuda_kernel.cu",
+            "ln_fwd_256.cu",
+            "ln_bwd_256.cu",
+            "ln_fwd_512.cu",
+            "ln_bwd_512.cu",
+            "ln_fwd_768.cu",
+            "ln_bwd_768.cu",
+            "ln_fwd_1024.cu",
+            "ln_bwd_1024.cu",
+            "ln_fwd_1280.cu",
+            "ln_bwd_1280.cu",
+            "ln_fwd_1536.cu",
+            "ln_bwd_1536.cu",
+            "ln_fwd_2048.cu",
+            "ln_bwd_2048.cu",
+            "ln_fwd_2560.cu",
+            "ln_bwd_2560.cu",
+            "ln_fwd_3072.cu",
+            "ln_bwd_3072.cu",
+            "ln_fwd_4096.cu",
+            "ln_bwd_4096.cu",
+            "ln_fwd_5120.cu",
+            "ln_bwd_5120.cu",
+            "ln_fwd_6144.cu",
+            "ln_bwd_6144.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3"] + generator_flag,

+ 1 - 2
flash_attn/ops/layer_norm.py

@@ -2,7 +2,6 @@
 import torch
 from torch.nn import init
 
-# from apex._autocast_utils import _cast_if_autocast_enabled
 import dropout_layer_norm
 
 
@@ -145,7 +144,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
 
 
 class DropoutAddLayerNorm(torch.nn.Module):
-    def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False,
+    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

+ 8 - 3
tests/ops/test_dropout_layer_norm.py

@@ -24,8 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
-# @pytest.mark.parametrize('hidden_size', [768])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                      dropout_p, has_residual, has_rowscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
@@ -148,7 +147,13 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
-@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+# @pytest.mark.parametrize('has_rowscale', [False])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                              dropout_p, has_residual, has_rowscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: