Kaynağa Gözat

Add fused_dense and dropout_add_layernorm CUDA extensions

Tri Dao 2 yıl önce
ebeveyn
işleme
fa6d1ce44f

+ 10 - 0
csrc/fused_dense_lib/README.md

@@ -0,0 +1,10 @@
+This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
+(forward and backward), adapted from Apex's
+[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
+We make it work for bfloat16.
+
+For best performance, you should use CUDA >= 11.8. CuBLAS versions before
+this doesn't have the best matmul + bias + gelu performance for bfloat16.
+```sh
+cd csrc/fused_dense_lib && pip install .
+```

+ 356 - 0
csrc/fused_dense_lib/fused_dense.cpp

@@ -0,0 +1,356 @@
+// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
+// We make it work for bfloat16
+#include <torch/extension.h>
+#include <torch/torch.h>
+#include <vector>
+
+#include <stdio.h>
+
+// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                                \
+  switch (TYPE) {                                                              \
+  case at::ScalarType::Half: {                                                 \
+    using scalar_t = at::Half;                                                 \
+    __VA_ARGS__();                                                             \
+    break;                                                                     \
+  }                                                                            \
+  case at::ScalarType::BFloat16: {                                             \
+    using scalar_t = at::BFloat16;                                             \
+    __VA_ARGS__();                                                             \
+    break;                                                                     \
+  }                                                                            \
+  default:                                                                     \
+    AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");            \
+  }
+
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
+template <typename T>
+int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
+
+template <typename T>
+int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input,  bool residual, void *lt_workspace);
+
+template <typename T>
+int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
+
+template <typename T>
+int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
+
+template <typename T>
+int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
+
+at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int out_features = weight.size(0);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
+    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
+    auto result = linear_bias_forward_cuda<scalar_t>(
+        input,
+        w_ptr,
+        bias,
+        in_features,
+        batch_size,
+        out_features,
+        out,
+        //out.data_ptr<scalar_t>(),
+       // reserved_space.data_ptr<scalar_t>(),
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_bias_forward failed.")
+  });
+
+  return {out};
+}
+
+std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int out_features = weight.size(0);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto opts = input.options();
+  auto d_weight = at::empty({out_features, in_features}, opts);
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
+  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
+#else
+  auto d_bias = at::empty({out_features}, opts);
+#endif
+  auto d_input = at::empty({batch_size, in_features}, opts);
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, opts);
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
+    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
+    auto result = linear_bias_backward_cuda<scalar_t>(
+        input.data_ptr<scalar_t>(),
+        w_ptr,
+        d_output.data_ptr<scalar_t>(),
+        in_features,
+        batch_size,
+        out_features,
+        d_weight.data_ptr<scalar_t>(),
+        d_bias.data_ptr<scalar_t>(),
+        d_input.data_ptr<scalar_t>(),
+       // reserved_space.data_ptr<scalar_t>(),
+        /*residual=*/false,
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_bias_backward failed.")
+  });
+
+  return {d_input, d_weight, d_bias};
+}
+
+std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int out_features = d_output.size(1);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto opts = input.options();
+  auto d_weight = at::empty({out_features, in_features}, opts);
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
+  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
+#else
+  auto d_bias = at::empty({out_features}, opts);
+#endif
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, opts);
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
+    auto result = linear_bias_wgrad_cuda<scalar_t>(
+        input.data_ptr<scalar_t>(),
+        d_output.data_ptr<scalar_t>(),
+        in_features,
+        batch_size,
+        out_features,
+        d_weight.data_ptr<scalar_t>(),
+        d_bias.data_ptr<scalar_t>(),
+       // reserved_space.data_ptr<scalar_t>(),
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
+  });
+
+  return {d_weight, d_bias};
+}
+
+std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int out_features = weight.size(0);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto opts = input.options();
+  auto d_weight = at::empty({out_features, in_features}, opts);
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
+  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
+#else
+  auto d_bias = at::empty({out_features}, opts);
+#endif
+  CHECK_SHAPE(d_input, batch_size, in_features);
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, opts);
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
+    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
+    auto result = linear_bias_backward_cuda<scalar_t>(
+        input.data_ptr<scalar_t>(),
+        w_ptr,
+        d_output.data_ptr<scalar_t>(),
+        in_features,
+        batch_size,
+        out_features,
+        d_weight.data_ptr<scalar_t>(),
+        d_bias.data_ptr<scalar_t>(),
+        d_input.data_ptr<scalar_t>(),
+       // reserved_space.data_ptr<scalar_t>(),
+        /*residual=*/true,
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
+  });
+
+  return {d_input, d_weight, d_bias};
+}
+
+std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
+                                            bool save_gelu_in, int heuristic) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int out_features = weight.size(0);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto opts = input.options();
+  auto output = at::empty({batch_size, out_features}, opts);
+  at::Tensor gelu_in;
+  if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, opts);
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
+    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
+    scalar_t* b_ptr = bias.data_ptr<scalar_t>();
+    auto result = linear_gelu_forward_cuda<scalar_t>(
+        input.data_ptr<scalar_t>(),
+        w_ptr,
+        b_ptr,
+        in_features,
+        batch_size,
+        out_features,
+        heuristic,
+        output.data_ptr<scalar_t>(),
+        save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
+       // reserved_space.data_ptr<scalar_t>(),
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
+  });
+
+  std::vector<at::Tensor> result = {output};
+  if (save_gelu_in) { result.push_back(gelu_in); };
+  return result;
+}
+
+std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int hidden_features = weight1.size(0);
+  int out_features = weight2.size(0);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto opts = input.options();
+  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
+  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
+  auto d_bias1 = at::empty({hidden_features}, opts);
+  auto d_bias2 = at::empty({out_features}, opts);
+  auto d_input = at::empty({batch_size, in_features}, opts);
+  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, opts);
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
+    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
+    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
+    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
+        input.data_ptr<scalar_t>(),
+        gelu_in.data_ptr<scalar_t>(),
+        output1.data_ptr<scalar_t>(),
+        weight1.data_ptr<scalar_t>(),
+        weight2.data_ptr<scalar_t>(),
+        d_output1.data_ptr<scalar_t>(),
+        d_output2.data_ptr<scalar_t>(),
+        in_features,
+        batch_size,
+        hidden_features,
+        out_features,
+        heuristic,
+        d_weight1.data_ptr<scalar_t>(),
+        d_weight2.data_ptr<scalar_t>(),
+        d_bias1.data_ptr<scalar_t>(),
+        d_bias2.data_ptr<scalar_t>(),
+        d_input.data_ptr<scalar_t>(),
+       // reserved_space.data_ptr<scalar_t>(),
+        /*residual=*/false,
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
+  });
+
+  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
+}
+
+std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
+
+  auto batch_size = input.size(0);
+  auto in_features = input.size(1);
+
+  int hidden_features = weight1.size(0);
+  int out_features = weight2.size(0);
+
+  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+
+  // create output/workspace tensor
+  auto opts = input.options();
+  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
+  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
+  auto d_bias1 = at::empty({hidden_features}, opts);
+  auto d_bias2 = at::empty({out_features}, opts);
+  CHECK_SHAPE(d_input, batch_size, in_features);
+  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
+  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, opts);
+
+  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
+    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
+    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
+    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
+        input.data_ptr<scalar_t>(),
+        gelu_in.data_ptr<scalar_t>(),
+        output1.data_ptr<scalar_t>(),
+        weight1.data_ptr<scalar_t>(),
+        weight2.data_ptr<scalar_t>(),
+        d_output1.data_ptr<scalar_t>(),
+        d_output2.data_ptr<scalar_t>(),
+        in_features,
+        batch_size,
+        hidden_features,
+        out_features,
+        heuristic,
+        d_weight1.data_ptr<scalar_t>(),
+        d_weight2.data_ptr<scalar_t>(),
+        d_bias1.data_ptr<scalar_t>(),
+        d_bias2.data_ptr<scalar_t>(),
+        d_input.data_ptr<scalar_t>(),
+       // reserved_space.data_ptr<scalar_t>(),
+        /*residual=*/true,
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
+    TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
+  });
+
+  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
+  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
+  m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
+  m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
+  m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
+  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
+  m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
+}

+ 1336 - 0
csrc/fused_dense_lib/fused_dense_cuda.cu

@@ -0,0 +1,1336 @@
+// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <torch/torch.h>
+
+/* Includes, cuda */
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
+// includes cublaslt
+#include <cublasLt.h>
+#endif
+
+// FP16 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemm_bias(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    at::Half* A,
+    int lda,
+    at::Half* B,
+    int ldb,
+    const float* beta,
+    at::Half* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_16F,
+      lda,
+      B,
+      CUDA_R_16F,
+      ldb,
+      beta,
+      C,
+      CUDA_R_16F,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
+// BF16 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemm_bias(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    at::BFloat16* A,
+    int lda,
+    at::BFloat16* B,
+    int ldb,
+    const float* beta,
+    at::BFloat16* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_16BF,
+      lda,
+      B,
+      CUDA_R_16BF,
+      ldb,
+      beta,
+      C,
+      CUDA_R_16BF,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+
+int gemm_bias_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::Half* A,
+    int lda,
+    at::Half* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::Half* C,
+    int ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias,
+    const void* bias) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+      epilogue = CUBLASLT_EPILOGUE_BIAS;
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          //&heuristicResult.algo,
+                          NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_bias_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::BFloat16* A,
+    int lda,
+    at::BFloat16* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::BFloat16* C,
+    int ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias,
+    const void* bias) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+      epilogue = CUBLASLT_EPILOGUE_BIAS;
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          //&heuristicResult.algo,
+                          NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_bias_gelu_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::Half* A,
+    int lda,
+    at::Half* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::Half* C,
+    int64_t ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias,
+    int heuristic,
+    const void* gelu_in,
+    const void* bias) {
+  bool save_gelu_in = gelu_in != nullptr;
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  constexpr int requestedAlgoCount = 5;
+  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
+  cublasLtEpilogue_t epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (save_gelu_in) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
+  }
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          // &heuristicResult.algo,
+                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
+                          &heuristicResult[heuristic].algo,
+                          // NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_bias_gelu_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::BFloat16* A,
+    int lda,
+    at::BFloat16* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::BFloat16* C,
+    int64_t ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias,
+    int heuristic,
+    const void* gelu_in,
+    const void* bias) {
+  bool save_gelu_in = gelu_in != nullptr;
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  constexpr int requestedAlgoCount = 5;
+  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
+  cublasLtEpilogue_t epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (save_gelu_in) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
+  }
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          // &heuristicResult.algo,
+                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
+                          &heuristicResult[heuristic].algo,
+                          // NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_bgradb_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::Half* A,
+    int lda,
+    at::Half* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::Half* C,
+    int ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias,
+    const void* bgrad) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+      epilogue = CUBLASLT_EPILOGUE_BGRADB;
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          //&heuristicResult.algo,
+                          NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_bgradb_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::BFloat16* A,
+    int lda,
+    at::BFloat16* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::BFloat16* C,
+    int ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias,
+    const void* bgrad) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+      epilogue = CUBLASLT_EPILOGUE_BGRADB;
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          //&heuristicResult.algo,
+                          NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_dgelu_bgradb_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::Half* A,
+    int lda,
+    at::Half* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::Half* C,
+    int64_t ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    int heuristic,
+    const void *gelu_in,
+    const void *bgrad) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  constexpr int requestedAlgoCount = 5;
+  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          //&heuristicResult.algo,
+                          &heuristicResult[heuristic].algo,
+                          // NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int gemm_dgelu_bgradb_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float *alpha, /* host pointer */
+    at::BFloat16* A,
+    int lda,
+    at::BFloat16* B,
+    int ldb,
+    const float *beta, /* host pointer */
+    at::BFloat16* C,
+    int64_t ldc,
+    void *workspace,
+    size_t workspaceSize,
+    cudaStream_t stream,
+    int heuristic,
+    const void *gelu_in,
+    const void *bgrad) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults                             = 0;
+  constexpr int requestedAlgoCount = 5;
+  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(
+    &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(
+    &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(
+    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(
+    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A,
+                          &Adesc,
+                          B,
+                          &Bdesc,
+                          beta,
+                          C,
+                          &Cdesc,
+                          C,
+                          &Cdesc,
+                          //&heuristicResult.algo,
+                          &heuristicResult[heuristic].algo,
+                          // NULL,
+                          workspace,
+                          workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+#endif
+
+template <typename T>
+int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    // Get the stream from cublas handle to reuse for biasReLU kernel.
+    cudaStream_t stream;
+    cublasGetStream(handle, &stream);
+    const float alpha          = 1.0;
+    const float beta_zero       = 0.0;
+    const float beta_one       = 1.0;
+    int status = 1;
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+    status = gemm_bias_lt(
+    (cublasLtHandle_t)handle,
+    CUBLAS_OP_T,
+    CUBLAS_OP_N,
+    out_features,
+    batch_size,
+    in_features,
+    &alpha, /* host pointer */
+    weight,
+    in_features,
+    input.data_ptr<T>(),
+    in_features,
+    &beta_zero, /* host pointer */
+    output.data_ptr<T>(),
+    out_features,
+    lt_workspace,
+    1 << 22,
+    stream,
+    true,
+    static_cast<const void*>(bias.data_ptr<T>()));
+#endif
+    if (status != 0){
+        output.copy_(bias);
+        status = gemm_bias(
+          handle,
+          CUBLAS_OP_T,
+          CUBLAS_OP_N,
+          out_features,
+          batch_size,
+          in_features,
+          &alpha,
+          weight,
+          in_features,
+          input.data_ptr<T>(),
+          in_features,
+          &beta_one,
+          output.data_ptr<T>(),
+          out_features);
+    }
+    return status;
+}
+
+    
+template <typename T>
+int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    // Get the stream from cublas handle to reuse for biasReLU kernel.
+    cudaStream_t stream;
+    cublasGetStream(handle, &stream);
+    const float alpha          = 1.0;
+    const float beta_zero      = 0.0;
+    const float beta           = residual ? 1.0 : 0.0;
+    int status = 1;
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+    status = gemm_bgradb_lt(
+    (cublasLtHandle_t)handle,
+    CUBLAS_OP_N,
+    CUBLAS_OP_T,
+    in_features,
+    out_features,
+    batch_size,
+    &alpha, /* host pointer */
+    input,
+    in_features,
+    d_output,
+    out_features,
+    &beta_zero, /* host pointer */
+    d_weight,
+    in_features,
+    lt_workspace,
+    1 << 22,
+    stream,
+    true,
+    static_cast<const void*>(d_bias));
+#endif
+    
+
+    if (status != 0){
+    
+        status = gemm_bias(
+          handle,
+          CUBLAS_OP_N,
+          CUBLAS_OP_T,
+          in_features,
+          out_features,
+          batch_size,
+          &alpha,
+          input,
+          in_features,
+          d_output,
+          out_features,
+          &beta_zero,
+          d_weight,
+          in_features);
+    }
+    
+    status = gemm_bias(
+      handle,
+      CUBLAS_OP_N,
+      CUBLAS_OP_N,
+      in_features,
+      batch_size,
+      out_features,
+      &alpha,
+      weight,
+      in_features,
+      d_output,
+      out_features,
+      &beta,
+      d_input,
+      in_features);
+    return status;
+
+}
+
+template <typename T>
+int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    // Get the stream from cublas handle to reuse for biasReLU kernel.
+    cudaStream_t stream;
+    cublasGetStream(handle, &stream);
+    const float alpha          = 1.0;
+    const float beta_zero      = 0.0;
+    int status = 1;
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+    status = gemm_bgradb_lt(
+    (cublasLtHandle_t)handle,
+    CUBLAS_OP_N,
+    CUBLAS_OP_T,
+    in_features,
+    out_features,
+    batch_size,
+    &alpha, /* host pointer */
+    input,
+    in_features,
+    d_output,
+    out_features,
+    &beta_zero, /* host pointer */
+    d_weight,
+    in_features,
+    lt_workspace,
+    1 << 22,
+    stream,
+    true,
+    static_cast<const void*>(d_bias));
+#endif
+
+
+    if (status != 0){
+
+        status = gemm_bias(
+          handle,
+          CUBLAS_OP_N,
+          CUBLAS_OP_T,
+          in_features,
+          out_features,
+          batch_size,
+          &alpha,
+          input,
+          in_features,
+          d_output,
+          out_features,
+          &beta_zero,
+          d_weight,
+          in_features);
+    }
+
+    return status;
+}
+
+template <typename T>
+int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    // Get the stream from cublas handle to reuse for biasReLU kernel.
+    cudaStream_t stream;
+    cublasGetStream(handle, &stream);
+    const float alpha          = 1.0;
+    const float beta_zero       = 0.0;
+    int status = 1;
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+    status = gemm_bias_gelu_lt(
+    (cublasLtHandle_t)handle,
+    CUBLAS_OP_T,
+    CUBLAS_OP_N,
+    out_features,
+    batch_size,
+    in_features,
+    &alpha, /* host pointer */
+    weight,
+    in_features,
+    input,
+    in_features,
+    &beta_zero, /* host pointer */
+    output,
+    out_features,
+    lt_workspace,
+    1 << 22,
+    stream,
+    true,
+    heuristic,
+    static_cast<const void*>(gelu_in),
+    static_cast<const void*>(bias));
+    return status;
+#else
+    return 1;
+#endif
+}
+
+template <typename T>
+int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    // Get the stream from cublas handle to reuse for biasReLU kernel.
+    cudaStream_t stream;
+    cublasGetStream(handle, &stream);
+    const float alpha          = 1.0;
+    const float beta_zero      = 0.0;
+    const float beta           = residual ? 1.0 : 0.0;
+    int status = 1;
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+//wgrad for first gemm
+    status = gemm_bgradb_lt(
+    (cublasLtHandle_t)handle,
+    CUBLAS_OP_N,
+    CUBLAS_OP_T,
+    hidden_features,
+    out_features,
+    batch_size,
+    &alpha, /* host pointer */
+    output1,
+    hidden_features,
+    d_output2,
+    out_features,
+    &beta_zero, /* host pointer */
+    d_weight2,
+    hidden_features,
+    lt_workspace,
+    1 << 22,
+    stream,
+    true,
+    static_cast<const void*>(d_bias2));
+//dgrad for second GEMM
+    status = gemm_dgelu_bgradb_lt(
+    (cublasLtHandle_t)handle,
+    CUBLAS_OP_N,
+    CUBLAS_OP_N,
+    hidden_features,
+    batch_size,
+    out_features,
+    &alpha, /* host pointer */
+    weight2,
+    hidden_features,
+    d_output2,
+    out_features,
+    &beta_zero, /* host pointer */
+    d_output1,
+    hidden_features,
+    lt_workspace,
+    1 << 22,
+    stream,
+    heuristic,
+    static_cast<const void*>(gelu_in),
+    static_cast<const void*>(d_bias1));
+//wgrad for the first GEMM
+    status = gemm_bias(
+      handle,
+      CUBLAS_OP_N,
+      CUBLAS_OP_T,
+      in_features,
+      hidden_features,
+      batch_size,
+      &alpha,
+      input,
+      in_features,
+      d_output1,
+      hidden_features,
+      &beta_zero,
+      d_weight1,
+      in_features);
+
+//dgrad for the first GEMM
+    status = gemm_bias(
+      handle,
+      CUBLAS_OP_N,
+      CUBLAS_OP_N,
+      in_features,
+      batch_size,
+      hidden_features,
+      &alpha,
+      weight1,
+      in_features,
+      d_output1,
+      hidden_features,
+      &beta,
+      d_input,
+      in_features);
+#endif
+    return status;
+
+}
+
+
+template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
+template int linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
+
+template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, bool residual, void *lt_workspace) ;
+template int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, bool residual, void *lt_workspace) ;
+
+template int linear_bias_wgrad_cuda<at::Half>(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace) ;
+template int linear_bias_wgrad_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace) ;
+
+template int linear_gelu_forward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace) ;
+template int linear_gelu_forward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace) ;
+
+template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, bool residual, void *lt_workspace);
+template int linear_gelu_linear_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, bool residual, void *lt_workspace);

+ 42 - 0
csrc/fused_dense_lib/setup.py

@@ -0,0 +1,42 @@
+import os
+import subprocess
+
+import torch
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+setup(
+    name='fused_dense_lib',
+    ext_modules=[
+        CUDAExtension(
+            name='fused_dense_lib',
+            sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
+            extra_compile_args={
+                               'cxx': ['-O3',],
+                               'nvcc': append_nvcc_threads(['-O3'])
+                               }
+            )
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+})
+

+ 6 - 0
csrc/layer_norm/README.md

@@ -0,0 +1,6 @@
+This CUDA extensions implements fused dropout + residual + LayerNorm, based on
+Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
+We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
+```sh
+cd csrc/layer_norm && pip install .
+```

+ 226 - 0
csrc/layer_norm/ln.h

@@ -0,0 +1,226 @@
+#pragma once
+
+#include <unordered_map>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#ifdef OLD_GENERATOR_PATH
+#include <ATen/CUDAGeneratorImpl.h>
+#else
+#include <ATen/cuda/CUDAGeneratorImpl.h>
+#endif
+
+namespace layer_norm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Params> 
+struct LaunchParams{
+
+    size_t elts_per_thread;
+    size_t workspace_bytes;
+    size_t barrier_size;
+
+    cudaDeviceProp * props;
+
+    cudaStream_t stream;
+
+    Params params;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ParamsBase {
+    ParamsBase()
+        : ctas_per_col(0)
+        , rows(0)
+        , cols(0)
+        , x(nullptr)
+        , mu(nullptr)
+        , rs(nullptr)
+        , gamma(nullptr)
+        , dropout_keep_p(1.f)
+        , dropout_scale(1.f)
+        , workspace(nullptr)
+        , barrier(nullptr)
+    {
+    }
+
+    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
+    int ctas_per_col;
+
+    // Input is interpreted as matrix. We normalize across columns.
+    int rows;
+    int cols;
+
+    // Common data pointers.
+    void *x0;
+    void *x1;
+    void *x;
+    void *dmask;
+    void *mu;
+    void *rs;
+    void *gamma;
+    void *rowscale;
+
+    float dropout_keep_p;
+    float dropout_scale;
+
+    // Multi-CTA workspace in gmem.
+    void *workspace;
+
+    // Multi-CTA sync barriers in gmem.
+    int *barrier;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct FwdParams : public ParamsBase {
+    FwdParams()
+        : ParamsBase()
+        , z(nullptr)
+        , beta(nullptr)
+        , epsilon(0.f)
+    {
+    }
+
+    // Output of LN FWD.
+    void *z;
+    void *beta;
+    float epsilon;
+
+    // Random state.
+    at::PhiloxCudaState philox_args;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct BwdParams : public ParamsBase {
+    BwdParams()
+        : ParamsBase()
+        , dz(nullptr)
+        , dx(nullptr)
+        , dbeta_part(nullptr)
+        , dgamma_part(nullptr)
+        , dx0(nullptr)
+        , dx1(nullptr)
+        , dbeta(nullptr)
+        , dgamma(nullptr)
+    {
+    }
+
+    // Input: gradient wrt. LN FWD output.
+    void *dz;
+    // Input: gradient wrt residual.
+    void *dx;
+
+    // Workspace for Wgrad pre-reduction.
+    void *dbeta_part;
+    void *dgamma_part;
+
+    // Output: Dgrad.
+    void *dx0;
+    void *dx1;
+    // Output: Wgrad.
+    void *dbeta;
+    void *dgamma;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
+using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
+using FunctionKey = uint64_t;
+using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
+using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
+
+extern FwdRegistry FWD_FUNCS;
+extern BwdRegistry BWD_FUNCS;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+using fp32 = float;
+using fp16 = half;
+using bf16 = nv_bfloat16;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct TypeId{};
+
+template<>
+struct TypeId<fp16>{
+    constexpr static uint32_t Value = 0;
+};
+
+template<>
+struct TypeId<bf16>{
+    constexpr static uint32_t Value = 1;
+};
+
+template<>
+struct TypeId<fp32>{
+    constexpr static uint32_t Value = 2;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int S>
+struct Type2Key{
+    constexpr static uint32_t Value = TypeId<T>::Value << S;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct WeightType2Key : public Type2Key<T, 0>{};
+
+template<typename T>
+struct InputType2Key : public Type2Key<T, 2>{};
+
+template<typename T>
+struct ResidualType2Key : public Type2Key<T, 4>{};
+
+template<typename T>
+struct OutputType2Key : public Type2Key<T, 6>{};
+
+template<typename T>
+struct ComputeType2Key : public Type2Key<T, 8>{};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C>
+struct Types2Key{
+    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
+    constexpr static inline uint64_t get(const uint64_t hidden_size){
+        constexpr uint64_t type_key = Value;
+        return (type_key << 32) | hidden_size;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct FwdRegistrar{
+    FwdRegistrar(FwdFunction f){
+        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
+        FWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
+struct BwdRegistrar{
+    BwdRegistrar(BwdFunction f){
+        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
+        BWD_FUNCS.insert({ key, f });
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace layer_norm

+ 455 - 0
csrc/layer_norm/ln_api.cpp

@@ -0,0 +1,455 @@
+#include <torch/extension.h>
+#include "ATen/cuda/CUDAContext.h"
+
+#include "ln.h"
+
+/*
+
+Supported Type combinations:
+
+input  residual   compute   weights   output
+============================================
+fp32     fp32      fp32      fp32      fp32
+fp16     fp32      fp32      fp32      fp16
+fp16     fp16      fp32      fp32      fp16
+bf16     fp32      fp32      fp32      bf16
+bf16     bf16      fp32      fp32      bf16
+fp16     fp16      fp32      fp16      fp16
+bf16     bf16      fp32      bf16      bf16
+
+Remarks:
+Output type = Input type
+Compute always in FP32
+
+*/
+
+namespace layer_norm {
+
+// Create registries and provide runtime versions of config hash functions.
+
+FwdRegistry FWD_FUNCS;
+BwdRegistry BWD_FUNCS;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+uint32_t get_type_id(torch::Dtype dtype){
+    if( dtype == torch::kFloat16 ) {
+        return TypeId<fp16>::Value;
+    } else if( dtype == torch::kBFloat16 ) {
+        return TypeId<bf16>::Value;
+    } else if( dtype == torch::kFloat32 ) {
+        return TypeId<fp32>::Value;
+    } else {
+        TORCH_CHECK(false, "Type not supported: ", dtype);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
+    using namespace layer_norm;
+    uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
+    uint64_t launcher_key = (type_key << 32) | hidden_size;
+    return launcher_key;
+}
+
+}  // namespace layer_norm
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
+    auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
+    if( iter != layer_norm::FWD_FUNCS.end() ) {
+        return iter->second;
+    } else {
+        TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
+    auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
+    if( iter != layer_norm::BWD_FUNCS.end() ) {
+        return iter->second;
+    } else {
+        TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
+                                           c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
+                                           const at::Tensor &gamma,   // hidden_size
+                                           const at::Tensor &beta,   // hidden_size
+                                           c10::optional<const at::Tensor> &rowscale_,      // BxS
+                                           const float dropout_p,
+                                           const float epsilon,
+                                           c10::optional<at::Generator> gen_,
+                                           bool residual_in_fp32
+) {
+    auto itype = x0.scalar_type();
+    auto rtype = x1_.has_value()
+        ? x1_.value().scalar_type()
+        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
+    auto wtype = gamma.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    TORCH_CHECK(beta.scalar_type() == wtype);
+
+    TORCH_CHECK(x0.is_cuda())
+    TORCH_CHECK(gamma.is_cuda())
+    TORCH_CHECK(beta.is_cuda())
+
+    TORCH_CHECK(x0.is_contiguous());
+    auto sizes = x0.sizes();
+    TORCH_CHECK(sizes.size() == 2);
+
+    const int rows = sizes[0];
+    const int cols = sizes[1];
+    auto hidden_size = gamma.numel();
+
+    if (x1_.has_value()) {
+        auto x1 = x1_.value();
+        TORCH_CHECK(x1.is_cuda())
+        TORCH_CHECK(x1.is_contiguous());
+        TORCH_CHECK(x1.sizes() == sizes);
+    }
+
+    if (rowscale_.has_value()) {
+        auto rowscale = rowscale_.value();
+        TORCH_CHECK(rowscale.is_cuda())
+        TORCH_CHECK(rowscale.is_contiguous());
+        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.scalar_type() == itype);
+    }
+
+    TORCH_CHECK(gamma.sizes() == beta.sizes());
+    TORCH_CHECK(hidden_size == cols);
+
+    TORCH_CHECK(epsilon >= 0.f);
+
+    auto opts = x0.options();
+
+    bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
+    at::Tensor x;
+    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
+    at::Tensor dmask;
+    if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); };
+    auto z = torch::empty(sizes, opts.dtype(otype));
+
+    auto mu = torch::empty({ rows }, opts.dtype(ctype));
+    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
+
+    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
+
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    // Request the kernel launcher.
+    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
+
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
+    // Set the kernel runtime parameters.
+    layer_norm::FwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x0 = x0.data_ptr();
+    params.x = save_x ? x.data_ptr() : nullptr;
+    params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma.data_ptr();
+    params.beta = beta.data_ptr();
+    params.z = z.data_ptr();
+    params.epsilon = epsilon;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+
+    if (dropout_p > 0.f) {
+        // number of times random will be generated per thread, to offset philox counter in thc random
+        // state
+        int64_t counter_offset = launch_params.elts_per_thread;
+
+        // See Note [Acquire lock when using random generators]
+        {
+            std::lock_guard<std::mutex> lock(gen->mutex_);
+            params.philox_args = gen->philox_cuda_state(counter_offset);
+        }
+    }
+
+    if( launch_params.barrier_size > 0 ) {
+        auto options = x0.options();
+        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    // Launch the kernel.
+    launcher(launch_params, false);
+
+    return { z, x, dmask, mu, rsigma };
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidden_size
+                                           const at::Tensor &x,      // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
+                                           const at::Tensor &mu,     // BxS, FP32!
+                                           const at::Tensor &rsigma, // BxS, FP32!
+                                           const at::Tensor &gamma,   // hidden_size
+                                           c10::optional<const at::Tensor> &rowscale_,      // BxS
+                                           const float dropout_p,
+                                           const bool has_residual
+) {
+
+    auto itype = dz.scalar_type();
+    auto rtype = x.scalar_type();
+    auto wtype = gamma.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
+
+    TORCH_CHECK(dz.dtype() == otype);
+    TORCH_CHECK(mu.dtype() == ctype);
+    TORCH_CHECK(rsigma.dtype() == ctype);
+
+    TORCH_CHECK(x.is_cuda());
+    TORCH_CHECK(dz.is_cuda());
+    TORCH_CHECK(mu.is_cuda());
+    TORCH_CHECK(rsigma.is_cuda());
+    TORCH_CHECK(gamma.is_cuda());
+
+    TORCH_CHECK(x.is_contiguous());
+    TORCH_CHECK(dz.is_contiguous());
+
+    auto sizes = x.sizes();
+    TORCH_CHECK(sizes.size() == 2);
+    TORCH_CHECK(dz.sizes() == sizes);
+    auto rows = sizes[0];
+    auto cols = sizes[1];
+
+    if (dmask_.has_value()) {
+        auto dmask = dmask_.value();
+        TORCH_CHECK(dmask.dtype() == mtype);
+        TORCH_CHECK(dmask.is_cuda());
+        TORCH_CHECK(dmask.is_contiguous());
+        TORCH_CHECK(dmask.sizes() == sizes);
+    }
+
+    if (rowscale_.has_value()) {
+        auto rowscale = rowscale_.value();
+        TORCH_CHECK(rowscale.is_cuda())
+        TORCH_CHECK(rowscale.is_contiguous());
+        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.scalar_type() == itype);
+    }
+
+    auto hidden_size = gamma.numel();
+
+    TORCH_CHECK(mu.numel() == rows);
+    TORCH_CHECK(mu.sizes() == rsigma.sizes());
+
+    TORCH_CHECK(gamma.numel() == cols);
+
+    auto opts = x.options();
+
+    auto dx0 = torch::empty_like(x, opts.dtype(itype));
+    at::Tensor dx1;
+    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+    auto dgamma = torch::empty_like(gamma);
+    auto dbeta = torch::empty_like(gamma);
+
+    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+
+    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
+
+    launcher(launch_params, true, /*prenorm=*/false);
+
+    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor workspace, barrier;
+
+    layer_norm::BwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x = x.data_ptr();
+    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma.data_ptr();
+    params.dz = dz.data_ptr();
+    params.dx0 = dx0.data_ptr();
+    params.dbeta = dbeta.data_ptr();
+    params.dgamma = dgamma.data_ptr();
+    params.dbeta_part = dbeta_part.data_ptr();
+    params.dgamma_part = dgamma_part.data_ptr();
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+
+    if( launch_params.barrier_size > 0 ) {
+        // TODO Any way to avoid this?
+        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    launcher(launch_params, false, /*prenorm=*/false);
+
+    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     // BxSxhidden_size
+                                                   const at::Tensor &dx,     // BxSxhidden_size
+                                                   const at::Tensor &x,      // BxSxhidden_size
+                                                   c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
+                                                   const at::Tensor &mu,     // BxS, FP32!
+                                                   const at::Tensor &rsigma, // BxS, FP32!
+                                                   const at::Tensor &gamma,   // hidden_size
+                                                   c10::optional<const at::Tensor> &rowscale_,      // BxS
+                                                   const float dropout_p,
+                                                   const bool has_residual
+) {
+
+    auto itype = dz.scalar_type();
+    auto rtype = x.scalar_type();
+    auto wtype = gamma.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
+
+    TORCH_CHECK(dz.dtype() == otype);
+    TORCH_CHECK(dx.dtype() == rtype);
+    TORCH_CHECK(mu.dtype() == ctype);
+    TORCH_CHECK(rsigma.dtype() == ctype);
+
+    TORCH_CHECK(x.is_cuda());
+    TORCH_CHECK(dz.is_cuda());
+    TORCH_CHECK(dx.is_cuda());
+    TORCH_CHECK(mu.is_cuda());
+    TORCH_CHECK(rsigma.is_cuda());
+    TORCH_CHECK(gamma.is_cuda());
+
+    TORCH_CHECK(x.is_contiguous());
+    TORCH_CHECK(dz.is_contiguous());
+    TORCH_CHECK(dx.is_contiguous());
+
+    auto sizes = x.sizes();
+    TORCH_CHECK(sizes.size() == 2);
+    TORCH_CHECK(dz.sizes() == sizes);
+    TORCH_CHECK(dx.sizes() == sizes);
+    auto rows = sizes[0];
+    auto cols = sizes[1];
+
+    if (dmask_.has_value()) {
+        auto dmask = dmask_.value();
+        TORCH_CHECK(dmask.dtype() == mtype);
+        TORCH_CHECK(dmask.is_cuda());
+        TORCH_CHECK(dmask.is_contiguous());
+        TORCH_CHECK(dmask.sizes() == sizes);
+    }
+
+    if (rowscale_.has_value()) {
+        auto rowscale = rowscale_.value();
+        TORCH_CHECK(rowscale.is_cuda())
+        TORCH_CHECK(rowscale.is_contiguous());
+        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.scalar_type() == itype);
+    }
+
+    auto hidden_size = gamma.numel();
+
+    TORCH_CHECK(mu.numel() == rows);
+    TORCH_CHECK(mu.sizes() == rsigma.sizes());
+
+    TORCH_CHECK(gamma.numel() == cols);
+
+    auto opts = x.options();
+
+    auto dx0 = torch::empty_like(x, opts.dtype(itype));
+    at::Tensor dx1;
+    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+    auto dgamma = torch::empty_like(gamma);
+    auto dbeta = torch::empty_like(gamma);
+
+    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+
+    // TODO: how to set template param for launcher
+    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
+
+    launcher(launch_params, true, /*prenorm=*/true);
+
+    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor workspace, barrier;
+
+    layer_norm::BwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x = x.data_ptr();
+    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma.data_ptr();
+    params.dz = dz.data_ptr();
+    params.dx = dx.data_ptr();
+    params.dx0 = dx0.data_ptr();
+    params.dbeta = dbeta.data_ptr();
+    params.dgamma = dgamma.data_ptr();
+    params.dbeta_part = dbeta_part.data_ptr();
+    params.dgamma_part = dgamma_part.data_ptr();
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+
+    if( launch_params.barrier_size > 0 ) {
+        // TODO Any way to avoid this?
+        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    launcher(launch_params, false, /*prenorm=*/true);
+
+    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.doc() = "CUDA DropoutAddLayerNorm";
+  m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
+  m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
+  m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
+}

+ 328 - 0
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -0,0 +1,328 @@
+#pragma once
+
+namespace layer_norm {
+
+template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_rowscale>
+__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
+void ln_bwd_kernel(layer_norm::BwdParams params) {
+
+    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
+    enum { WARPS_M = Ktraits::WARPS_M };
+    enum { WARPS_N = Ktraits::WARPS_N };
+    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
+    enum { COLS = Ktraits::COLS };
+    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
+    enum { LDGS = Ktraits::LDGS };
+    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
+    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
+    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
+
+    using input_t = typename Ktraits::input_t;
+    using compute_t = typename Ktraits::compute_t;
+    using index_t = typename Ktraits::index_t;
+    using mask_t = typename Ktraits::mask_t;
+    using Ivec = typename Ktraits::Ivec;
+    using Rvec = typename Ktraits::Rvec;
+    using Ovec = typename Ktraits::Ovec;
+    using Wvec = typename Ktraits::Wvec;
+    using Cvec = typename Ktraits::Cvec;
+    using Mvec = typename Ktraits::Mvec;
+    using Reducer = typename Ktraits::Reducer;
+    using reduce_t = typename Reducer::Type;
+
+    extern __shared__ char smem_[];
+
+    const index_t tidx = threadIdx.x;
+    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
+    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
+    const index_t lane = tidx % THREADS_PER_WARP;
+    const index_t warp = tidx / THREADS_PER_WARP;
+    const index_t warp_m = warp / Ktraits::WARPS_N;
+    const index_t warp_n = warp % Ktraits::WARPS_N;
+    const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
+
+    const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
+    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
+
+    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
+
+    Cvec dzy_sum[LDGS];
+    Cvec dz_sum[LDGS];
+
+    memset(dzy_sum, 0, sizeof(dzy_sum));
+    memset(dz_sum, 0, sizeof(dz_sum));
+
+    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
+    char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
+
+    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
+
+    Sum<reduce_t> sum;
+
+    constexpr float rn = 1.f / float(COLS);
+    Wvec gamma[LDGS];
+    index_t idx = c;
+    #pragma unroll
+    for( int it = 0; it < LDGS; it++ ) {
+        gamma[it].load_from(params.gamma, idx);
+        idx += Ktraits::VEC_COLS_PER_LDG;
+    }
+    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
+    // last blocks with syncthreads!
+    // grid stride over rows
+    #pragma unroll 1
+    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
+        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
+        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
+        const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast<const input_t *>(params.rowscale)[row]) : 1.0f;
+        Mvec dmask[LDGS];
+        Rvec dx[LDGS];
+        compute_t dy[LDGS * NUM_ELTS];
+        compute_t y[LDGS * NUM_ELTS];
+        compute_t mdy_local = 0.f;
+        compute_t mdyy_local = 0.f;
+        index_t idx = row * Ktraits::VEC_COLS + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            Rvec x;
+            Ovec dz;
+            dz.load_from(params.dz, idx);
+            if (Prenorm) { dx[it].load_from(params.dx, idx); }
+            x.load_from(params.x, idx);
+            if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
+            idx += Ktraits::VEC_COLS_PER_LDG;
+            #pragma unroll
+            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                compute_t x_tmp = x.data.elt[jt];
+                compute_t y_tmp = rs_r * (x_tmp - mu_r);
+                compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
+                dy_tmp *= compute_t(dz.data.elt[jt]);
+                compute_t dz_tmp = dz.data.elt[jt];
+
+                mdy_local += dy_tmp;
+                mdyy_local += dy_tmp * y_tmp;
+
+                dy[it * NUM_ELTS + jt] = dy_tmp;
+                y[it * NUM_ELTS + jt] = y_tmp;
+
+                dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
+                dz_sum[it].data.elt[jt] += dz_tmp;
+            }
+        }
+
+        reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
+        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
+        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
+
+        idx = row * Ktraits::VEC_COLS + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            Ivec dx0;
+            Rvec dx1;
+            #pragma unroll
+            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                compute_t dy_tmp = dy[it * NUM_ELTS + jt];
+                compute_t y_tmp = y[it * NUM_ELTS + jt];
+                compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
+                compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
+                if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res;
+                if (Is_dropout) {
+                    dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
+                } else {
+                    dx0.data.elt[jt] = dx0_tmp_res;
+                }
+            }
+            if (Has_residual) { dx1.store_to(params.dx1, idx); }
+            dx0.store_to(params.dx0, idx);
+            idx += Ktraits::VEC_COLS_PER_LDG;
+        }
+
+    }  // end: grid stride loop
+
+    if( WARPS_M == 1 ) {
+        idx = r * Ktraits::VEC_COLS + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            dz_sum[it].store_to(params.dbeta_part, idx);
+            dzy_sum[it].store_to(params.dgamma_part, idx);
+            idx += Ktraits::VEC_COLS_PER_LDG;
+        }
+    } else {
+        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
+        // Finalize reduction of part dgamma and dbeta for this CTA
+        // by reducing over the rows held across the WARPS_M warps
+
+        // Assumption: blockSize divides hidden size.
+        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
+        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
+
+        idx = warp_m * Ktraits::VEC_COLS + tid_r;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            dz_sum[it].store_to(smem_wgrad, idx);
+            idx += THREADS_PER_ROW;
+        }
+        __syncthreads();
+        compute_t cta_dz_sum[NUM_RES];
+        memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
+        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+            for( int jt = 0; jt < NUM_RES; jt++ ) {
+                cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+            }
+        }
+        __syncthreads();
+
+        idx = warp_m * Ktraits::VEC_COLS + tid_r;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            dzy_sum[it].store_to(smem_wgrad, idx);
+            idx += THREADS_PER_ROW;
+        }
+        __syncthreads();
+        compute_t cta_dzy_sum[NUM_RES];
+        memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
+        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+            for( int jt = 0; jt < NUM_RES; jt++ ) {
+                cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+            }
+        }
+
+        compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
+        for( int jt = 0; jt < NUM_RES; jt++ ) {
+            *dgamma_part = cta_dzy_sum[jt];
+            dgamma_part += Ktraits::THREADS_PER_CTA;
+        }
+
+        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
+        for( int jt = 0; jt < NUM_RES; jt++ ) {
+            *dbeta_part = cta_dz_sum[jt];
+            dbeta_part += Ktraits::THREADS_PER_CTA;
+        }
+    }
+}
+
+template<typename Kernel_traits>
+__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
+void ln_bwd_finalize_kernel(BwdParams params)
+{
+
+    using compute_t = typename Kernel_traits::compute_t;
+    using weight_t = typename Kernel_traits::weight_t;
+    using index_t = typename Kernel_traits::index_t;
+    using Reducer = typename Kernel_traits::Reducer;
+    using reduce_t = typename Reducer::Type;
+
+    Sum<reduce_t> sum;
+    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
+    enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
+
+    __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
+
+    constexpr uint32_t bidm = 0;
+
+    const uint32_t bidn = blockIdx.x;
+    const uint32_t tidx = threadIdx.x;
+    const uint32_t warp = tidx / THREADS_PER_WARP;
+    const uint32_t lane = tidx % THREADS_PER_WARP;
+
+    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
+
+    const uint32_t c = bidn * THREADS_PER_WARP + lane;
+    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
+    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
+    for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
+        // Each thread sums over NUM_ELT columns.
+        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
+        memset(&dgamma_local, 0, sizeof(dgamma_local));
+        memset(&dbeta_local, 0, sizeof(dbeta_local));
+        for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
+            index_t idx = row * Kernel_traits::COLS + col;
+
+            Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
+            dbeta_part.load_from(params.dbeta_part, idx);
+            dgamma_part.load_from(params.dgamma_part, idx);
+            #pragma unroll
+            for( int it = 0; it < NUM_ELT; it++ ) {
+                dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
+                dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+            }
+        }
+
+        void * smem_gamma = smem_;
+        void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
+
+        const int write_row = warp;
+        const int write_col = lane ^ write_row;
+        const int write_idx = write_row * THREADS_PER_WARP + write_col;
+
+        dgamma_local.store_to(smem_gamma, write_idx);
+        dbeta_local.store_to(smem_beta, write_idx);
+
+        __syncthreads();
+
+        // It would be probably safe to reuse the first row of smem_beta and smem_gamma
+        void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+
+
+        // More than one iter iff ROWS_PER_CTA < 32.
+        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
+            const int read_row = lane;
+            const int read_col = w ^ read_row;
+            const int read_idx = read_row * THREADS_PER_WARP + read_col;
+
+            memset(&dbeta_local, 0, sizeof(dbeta_local));
+            memset(&dgamma_local, 0, sizeof(dgamma_local));
+
+            // Load beta and gamma transposed 
+            if(read_row < Kernel_traits::ROWS_PER_CTA){
+                dbeta_local.load_from(smem_beta, read_idx);
+                dgamma_local.load_from(smem_gamma, read_idx);
+            }
+
+            // Call reducer on the loaded value(s) and convert.
+            #pragma unroll
+            for( int it = 0; it < NUM_ELT; it++ ) {
+                compute_t b_i = dbeta_local.data.elt[it];
+                compute_t g_i = dgamma_local.data.elt[it];
+                b_i = reducer.allreduce(b_i, sum);
+                g_i = reducer.allreduce(g_i, sum);
+
+                dgamma_local.data.elt[it] = g_i;
+                dbeta_local.data.elt[it] = b_i;
+            }
+
+            // Leader stores the result at the current column.
+            if(lane == 0){
+                dgamma_local.store_to(smem_gamma_out, w);
+                dbeta_local.store_to(smem_beta_out, w);
+            }
+
+        }
+
+        // All writes done.
+        __syncthreads();
+
+        // Pack and store: 2-wide stores with half the threads.
+        if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
+
+            using src_t = typename TypeToVec2<compute_t>::Type;
+            using dst_t = typename TypeToVec2<weight_t>::Type;
+            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
+            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+
+            dgamma_vec2.load_from(smem_gamma_out, lane);
+            dbeta_vec2.load_from(smem_beta_out, lane);
+            #pragma unroll
+            for( int it = 0; it < NUM_ELT; it++ ) {
+                dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
+                dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
+            }
+            dgamma_out2.store_to(params.dgamma, col_out);
+            dbeta_out2.store_to(params.dbeta, col_out);
+
+        }
+    }
+}
+}  // namespace layer_norm

+ 325 - 0
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu

@@ -0,0 +1,325 @@
+#include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "ln_bwd_kernels.cuh"
+#include "static_switch.h"
+
+using namespace layer_norm;
+
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE, 
+    int CTAS_PER_ROW, 
+    int WARPS_M, 
+    int WARPS_N, 
+    int BYTES_PER_LDG_MAIN,
+    int BYTES_PER_LDG_FINAL
+>
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
+
+    using Kernel_traits = Kernel_traits<weight_t,
+                                        input_t,
+                                        residual_t,
+                                        output_t,
+                                        compute_t,
+                                        index_t,
+                                        HIDDEN_SIZE,
+                                        CTAS_PER_ROW,
+                                        WARPS_M,
+                                        WARPS_N,
+                                        BYTES_PER_LDG_MAIN
+                                        >;
+    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
+    bool has_residual = launch_params.params.dx1 != nullptr;
+    bool has_rowscale = launch_params.params.rowscale != nullptr;
+    BOOL_SWITCH(prenorm, PrenormConst, [&] {
+        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+                BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
+                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
+                    if( configure_params ) {
+                        int ctas_per_sm;
+                        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
+                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                        launch_params.barrier_size = 0;
+                        launch_params.workspace_bytes = 0;
+                        if(Kernel_traits::CTAS_PER_ROW > 1) {
+                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                          * Kernel_traits::WARPS_M
+                                                          * Kernel_traits::CTAS_PER_ROW
+                                                          * sizeof(typename Kernel_traits::reduce_t)
+                                                          * 2;
+                        }
+                        return;
+                    }
+
+                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
+                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
+                    }
+                    auto stream = launch_params.stream;
+                    auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
+                    } else {
+                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                        dim3 block(Kernel_traits::THREADS_PER_CTA);
+                        void *params_ = (void *)&launch_params.params;
+                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
+                    }
+
+                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
+                                                                              weight_t,
+                                                                              input_t,
+                                                                              residual_t,
+                                                                              output_t,
+                                                                              compute_t,
+                                                                              index_t,
+                                                                              32 * 32,  // THREADS_PER_CTA
+                                                                              BYTES_PER_LDG_FINAL>;
+
+                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
+                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                });
+            });
+        });
+    });
+}
+
+// Create backward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+
+REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
+REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
+
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+
+// TD [2022-04-22] Disable most of these to speed up compile time
+
+// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
+
+// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4,  4, 4);
+// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4,  4, 4);
+// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
+
+// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
+// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
+// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
+// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+
+// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+
+// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4,  4, 4);
+// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4,  4, 4);
+// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4,  8, 4);
+
+// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4,  8, 4);
+// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
+// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
+// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
+// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
+// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
+// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
+
+// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
+
+// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
+// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);

+ 302 - 0
csrc/layer_norm/ln_fwd_cuda_kernel.cu

@@ -0,0 +1,302 @@
+#include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "ln_fwd_kernels.cuh"
+#include "static_switch.h"
+
+using namespace layer_norm;
+
+template<
+    typename weight_t,
+    typename input_t,
+    typename residual_t,
+    typename output_t,
+    typename compute_t,
+    typename index_t,
+    int HIDDEN_SIZE, 
+    int CTAS_PER_ROW, 
+    int WARPS_M, 
+    int WARPS_N, 
+    int BYTES_PER_LDG
+>
+void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
+
+    using Kernel_traits = Kernel_traits<weight_t,
+                                        input_t,
+                                        residual_t,
+                                        output_t,
+                                        compute_t,
+                                        index_t,
+                                        HIDDEN_SIZE,
+                                        CTAS_PER_ROW,
+                                        WARPS_M,
+                                        WARPS_N,
+                                        BYTES_PER_LDG
+                                        >;
+    bool has_residual = launch_params.params.x1 != nullptr;
+    bool has_rowscale = launch_params.params.rowscale != nullptr;
+    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
+        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+            BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
+                auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
+                if( configure_params ) {
+                    int ctas_per_sm;
+                    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
+                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
+                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
+                    launch_params.barrier_size = 0;
+                    launch_params.workspace_bytes = 0;
+                    if(Kernel_traits::CTAS_PER_ROW > 1) {
+                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                      * Kernel_traits::WARPS_M
+                                                      * Kernel_traits::CTAS_PER_ROW
+                                                      * sizeof(typename Kernel_traits::Stats::stats_t)
+                                                      * 2;
+                    }
+                    return;
+                }
+
+                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
+                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+                }
+                auto stream = launch_params.stream;
+                auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
+                } else {
+                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                    dim3 block(Kernel_traits::THREADS_PER_CTA);
+                    void *params_ = (void *)&launch_params.params;
+                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
+                }
+            });
+        });
+    });
+}
+
+// Create forward launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
+
+REGISTER_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+
+REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+
+REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+
+REGISTER_FWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1,  4);
+REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1,  4);
+
+REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+
+REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+
+REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+
+REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+
+REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+
+// TD [2022-04-22] Disable most of these to speed up compile time
+
+// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
+// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
+// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
+// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
+
+// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4,  4);
+
+// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4,  4);
+
+// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4,  8);
+// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4,  8);
+// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4,  8);
+// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4,  8);
+// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4,  8);
+
+// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4,  8);
+// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4,  8);
+// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4,  4);
+
+// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4,  4);
+// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4,  4);
+
+// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
+
+// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
+// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);

+ 159 - 0
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -0,0 +1,159 @@
+#pragma once
+
+#ifdef OLD_GENERATOR_PATH
+#include <ATen/CUDAGeneratorImpl.h>
+#else
+#include <ATen/cuda/CUDAGeneratorImpl.h>
+#endif
+
+#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack
+#include <curand_kernel.h>
+
+#include "ln.h"
+
+namespace layer_norm {
+
+template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_rowscale>
+__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) 
+void ln_fwd_kernel(FwdParams params) {
+
+    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
+    enum { WARPS_N = Ktraits::WARPS_N };
+    enum { WARPS_M = Ktraits::WARPS_M };
+    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
+    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
+    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
+    enum { LDGS = Ktraits::LDGS };
+    enum { NUM_ELTS = Ktraits::NUM_ELTS };
+    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
+
+    using input_t = typename Ktraits::input_t;
+    using residual_t = typename Ktraits::residual_t;
+    using output_t = typename Ktraits::output_t;
+    using index_t = typename Ktraits::index_t;
+    using compute_t = typename Ktraits::compute_t;
+    using mask_t = typename Ktraits::mask_t;
+    using Ivec = typename Ktraits::Ivec;
+    using Rvec = typename Ktraits::Rvec;
+    using Ovec = typename Ktraits::Ovec;
+    using Wvec = typename Ktraits::Wvec;
+    using Cvec = typename Ktraits::Cvec;
+    using Mvec = typename Ktraits::Mvec;
+
+    using Stats = typename Ktraits::Stats;
+    using stats_t = typename Stats::stats_t;
+
+    constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
+
+    extern __shared__ char smem_[];
+
+    const index_t tidx = threadIdx.x;
+    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
+    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
+    const index_t lane = tidx % THREADS_PER_WARP;
+    const index_t warp = tidx / THREADS_PER_WARP;
+    const index_t warp_m = warp / WARPS_N;
+    const index_t warp_n = warp % WARPS_N;
+
+    const index_t r = bidm * ROWS_PER_CTA + warp_m;
+    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
+
+    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
+
+    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
+    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
+
+    const input_t *rowscale = static_cast<input_t *>(params.rowscale);
+
+    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
+    curandStatePhilox4_32_10_t state;
+    if (Is_dropout) {
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
+        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
+    }
+
+    Wvec gamma[LDGS];
+    Wvec beta[LDGS];
+    index_t idx = c;
+    #pragma unroll
+    for( int it = 0; it < LDGS; it++ ) {
+        gamma[it].load_from(params.gamma, idx);
+        beta[it].load_from(params.beta, idx);
+        idx += VEC_COLS_PER_LDG;
+    }
+
+    constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
+
+    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
+        const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f;
+        index_t idx = row * Ktraits::VEC_COLS + c;
+        compute_t xf[LDGS * NUM_ELTS];
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            Ivec x0;
+            Rvec x1;
+            Rvec x;
+            Mvec dmask;
+            x0.load_from(params.x0, idx);
+            if (Has_residual) { x1.load_from(params.x1, idx); }
+            #pragma unroll
+            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
+                // the more efficient curand_uniform4.
+                mask_t keep = true;
+                if (Is_dropout) {
+                    float rand = curand_uniform(&state);
+                    keep = mask_t(rand <= params.dropout_keep_p);
+                }
+                compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]);
+                compute_t x_ij;
+                if (Has_residual) {
+                    compute_t x1_ij = compute_t(x1.data.elt[jt]);
+                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
+                } else  {
+                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
+                }
+                if (save_x) { x.data.elt[jt] = x_ij; }
+                xf[it * NUM_ELTS + jt] = x_ij;
+                if (Is_dropout) { dmask.data.elt[jt] = keep; }
+            }
+            if (save_x) { x.store_to(params.x, idx); }
+            if (Is_dropout) { dmask.store_to(params.dmask, idx); }
+            idx += VEC_COLS_PER_LDG;
+        }
+
+        stats_t s = stats.compute(xf, rn);
+
+        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
+        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
+
+        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
+            mu_ptr[row] = mu;
+        }
+
+        compute_t rs = rsqrtf(rn * m2 + params.epsilon);
+
+        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
+            rs_ptr[row] = rs;
+        }
+
+        idx = row * Ktraits::VEC_COLS + c;
+        #pragma unroll
+        for( int it = 0; it < LDGS; it++ ) {
+            Ovec z;
+            #pragma unroll
+            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+                output_t g_ij = gamma[it].data.elt[jt];
+                output_t b_ij = beta[it].data.elt[jt];
+                z.data.elt[jt] = (g_ij * y_ij + b_ij);
+            }
+            z.store_to(params.z, idx);
+            idx += VEC_COLS_PER_LDG;
+        }
+
+    }
+}
+
+}  // namespace layer_norm

+ 170 - 0
csrc/layer_norm/ln_kernel_traits.h

@@ -0,0 +1,170 @@
+#pragma once
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace layer_norm {
+template<
+    uint32_t HIDDEN_SIZE_,
+    typename weight_t_,
+    typename input_t_,
+    typename residual_t_,
+    typename output_t_,
+    typename compute_t_,
+    typename index_t_,
+    uint32_t THREADS_PER_CTA_
+>
+struct Kernel_traits_base {
+
+    using weight_t = weight_t_;
+    using input_t = input_t_;
+    using residual_t = residual_t_;
+    using output_t = output_t_;
+    using compute_t = compute_t_;
+    using index_t = index_t_;
+
+    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
+    enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
+    enum { THREADS_PER_WARP = 32 };
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<
+    uint32_t HIDDEN_SIZE_,
+    typename weight_t_,
+    typename input_t_,
+    typename residual_t_,
+    typename output_t_,
+    typename compute_t_,
+    typename index_t_,
+    uint32_t THREADS_PER_CTA_,
+    uint32_t BYTES_PER_LDG_,
+    typename Base = Kernel_traits_base<HIDDEN_SIZE_,
+                                        weight_t_,
+                                        input_t_,
+                                        residual_t_,
+                                        output_t_,
+                                        compute_t_,
+                                        index_t_,
+                                        THREADS_PER_CTA_>
+>
+struct Kernel_traits_finalize : public Base {
+    enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
+    static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
+    // Bytes per global load from the input. 
+    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
+    // Number of elements fetched by a global load.
+    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
+    // Bytes per global store of the weights.
+    enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
+    static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
+    static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
+    // The total number of BYTES_PER_LDG-wide words in a hidden vector.
+    enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
+    static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
+
+    // Shared memory size to transpose the CTA result.
+    enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
+    // Shared memory size to coalsece the CTA result.
+    enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
+    // Shared memory requirement per CTA. 
+    enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
+
+    // The type of the reducer.
+    using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
+
+    // Condition for the whole CTA to participate in syncthreads.
+    static_assert(COLS % Base::THREADS_PER_WARP == 0);
+    enum { CTAS = COLS / Base::THREADS_PER_WARP };
+}; 
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+template<
+    typename weight_t_,
+    typename input_t_,
+    typename residual_t_,
+    typename output_t_,
+    typename compute_t_,
+    typename index_t_,
+    uint32_t HIDDEN_SIZE_, 
+    uint32_t CTAS_PER_ROW_, 
+    uint32_t WARPS_M_, 
+    uint32_t WARPS_N_, 
+    uint32_t BYTES_PER_LDG_ = 16,
+    typename Base = Kernel_traits_base<
+        HIDDEN_SIZE_,
+        weight_t_, 
+        input_t_,
+        residual_t_,
+        output_t_, 
+        compute_t_, 
+        index_t_, 
+        WARPS_M_*WARPS_N_*THREADS_PER_WARP
+        >
+>
+struct Kernel_traits : public Base {
+
+    using input_t = typename Base::input_t;
+    using residual_t = typename Base::residual_t;
+    using weight_t = typename Base::weight_t;
+    using compute_t = typename Base::compute_t;
+    using output_t = typename Base::output_t;
+    using index_t = typename Base::index_t;
+    // using mask_t = unsigned char;
+    using mask_t = bool;
+
+    enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
+    enum { WARPS_M = WARPS_M_ };
+    enum { WARPS_N = WARPS_N_ };
+    enum { COLS = HIDDEN_SIZE_ };
+    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
+    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
+    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
+
+    enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
+    enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
+    enum { ROWS_PER_CTA = WARPS_M };
+
+    enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
+    enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
+    // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
+    enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
+    static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
+
+    using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
+    using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>; 
+
+    enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
+    enum { SMEM_BYTES = SMEM_BYTES_DGRAD  + SMEM_BYTES_WGRAD };
+
+    using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
+    using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
+    using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
+    using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
+    using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
+    using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
+    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
+
+    // Assume that each thread can handle the same number of elements in the output and weights as in the input.
+    static_assert(sizeof(input_t) == sizeof(output_t));
+    static_assert(sizeof(input_t) <= sizeof(residual_t));
+    // The number of columns fetched per load from input: one per thread.
+    enum { VEC_COLS_PER_LDG =  CTAS_PER_ROW * THREADS_PER_ROW };
+    // The total number of vectorized loads/stores per hidden vector.
+    enum { VEC_COLS = COLS / ELTS_PER_LDG };
+    // The number of loads per thread for the input.
+    enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
+    static_assert(LDGS * VEC_COLS_PER_LDG  == VEC_COLS);
+    //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
+
+    using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
+    enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace layer_norm

+ 734 - 0
csrc/layer_norm/ln_utils.cuh

@@ -0,0 +1,734 @@
+#pragma once
+
+#include <cassert>
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include "ln.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+constexpr uint32_t THREADS_PER_WARP = 32;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline void check_cuda_(cudaError_t status, const char *file, int line) {
+    if( status != cudaSuccess ) {
+        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
+        exit(status);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define CHECK_CUDA(ans)                                                                                                        \
+    { check_cuda_((ans), __FILE__, __LINE__); }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG)                 \
+    void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,                      \
+                                                                                const bool configure_params) {                               \
+        launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(                    \
+            launch_params, configure_params);                                                                                                \
+    }                                                                                                                                        \
+    static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
+        ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define REGISTER_BWD_LAUNCHER(                                                                                                                  \
+    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                      \
+    void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,                         \
+                                                                                const bool configure_params, const bool prenorm) {              \
+        launch_<WTYPE,                                                                                                                          \
+                ITYPE,                                                                                                                          \
+                RTYPE,                                                                                                                          \
+                OTYPE,                                                                                                                          \
+                CTYPE,                                                                                                                          \
+                uint32_t,                                                                                                                       \
+                HIDDEN_SIZE,                                                                                                                    \
+                CTAS_PER_ROW,                                                                                                                   \
+                WARPS_M,                                                                                                                        \
+                WARPS_N,                                                                                                                        \
+                BYTES_PER_LDG,                                                                                                                  \
+                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm);                                                              \
+    }                                                                                                                                           \
+    static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(    \
+        ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 operator+(const float2 & a, const float2 & b){
+    return {a.x + b.x, a.y + b.y};
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void operator+=(float2 & a, const float2 & b){
+    a.x += b.x;
+    a.y += b.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct Sum {
+    inline __device__ Sum(){}
+    inline __device__ T operator()(const T &a, const T &b){
+        return a + b;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
+    return __shfl_xor_sync(uint32_t(-1), x, idx);
+}
+
+template<>
+inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
+    return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
+}
+
+template<typename T>
+inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
+    return __shfl_down_sync(uint32_t(-1), x, idx);
+}
+
+template<>
+inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
+    return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace layer_norm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct uint16 {
+    uint4 u;
+    uint4 v;
+    uint4 s;
+    uint4 t;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct uint8 {
+    uint4 u;
+    uint4 v;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int BYTES>
+struct BytesToType {};
+
+template<>
+struct BytesToType<64> {
+    using Type = uint16;
+    static_assert(sizeof(Type) == 64);
+};
+
+template<>
+struct BytesToType<32> {
+    using Type = uint8;
+    static_assert(sizeof(Type) == 32);
+};
+
+template<>
+struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<>
+struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<>
+struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<>
+struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<>
+struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct TypeToVec2 {};
+
+template<>
+struct TypeToVec2<float> {
+    using Type = float2;
+};
+
+template<>
+struct TypeToVec2<half> {
+    using Type = half2;
+};
+
+template<>
+struct TypeToVec2<nv_bfloat16> {
+    using Type = nv_bfloat162;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int INDEX>
+struct Get {
+    template<typename T, typename R>
+    static inline __device__ R of(const T &vec);
+};
+
+template<>
+template<typename T, typename R>
+inline __device__ R Get<0>::of(const T &vec) {
+    return vec.x;
+}
+
+template<>
+template<typename T, typename R>
+inline __device__ R Get<1>::of(const T &vec) {
+    return vec.y;
+}
+
+template<>
+template<typename T, typename R>
+inline __device__ R Get<2>::of(const T &vec) {
+    return vec.z;
+}
+
+template<>
+template<typename T, typename R>
+inline __device__ R Get<3>::of(const T &vec) {
+    return vec.w;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Src, typename Dst>
+struct Converter{
+    static inline __device__ Dst convert(const Src &from) {
+        return Dst(from);
+    }
+};
+
+template<>
+struct Converter<float2, half2>{
+    static inline __device__ half2 convert(const float2 &x) {
+        return __float22half2_rn(x);
+    }
+};
+
+template<>
+struct Converter<float2, nv_bfloat162>{
+    static inline __device__ nv_bfloat162 convert(const float2 &x) {
+#if __CUDA_ARCH__ >= 800
+        return __float22bfloat162_rn(x);
+#else
+        union {
+            nv_bfloat162 raw;
+            nv_bfloat16 x;
+            nv_bfloat16 y;
+        } tmp;
+        tmp.x = __float2bfloat16_rn(x.x);
+        tmp.y = __float2bfloat16_rn(x.y);
+        return tmp.raw;
+#endif
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct Zeros{
+    static inline __device__ T get() {
+        return T(0.f);
+    }
+};
+
+template<> 
+struct Zeros<float2>{
+    static inline __device__ float2 get() {
+        return make_float2(0.f, 0.f);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Elt_type, uint32_t NUM_ELT>
+struct Vec {
+
+    enum { BYTES = NUM_ELT * sizeof(Elt_type) };
+
+    using Vec_type = typename BytesToType<BYTES>::Type;
+
+    using Alias_type = union {
+        Vec_type vec;
+        Elt_type elt[NUM_ELT];
+    };
+
+    Alias_type data;
+
+    template<typename S>
+    inline __device__ void to(Vec<S, NUM_ELT> &other) {
+        #pragma unroll
+        for( int it = 0; it < NUM_ELT; it++ ) {
+            other.data.elt[it] = S(this->data.elt[it]);
+        }
+    }
+
+    template<typename Op>
+    inline __device__ void assign(const Op &op) {
+        #pragma unroll
+        for( int it = 0; it < NUM_ELT; it++ ) {
+            this->data.elt[it] = op(it);
+        }
+    }
+
+    inline __device__ void load_from(const void *base_ptr, const size_t idx) {
+        this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
+    }
+
+    inline __device__ void store_to(void *base_ptr, const size_t idx) {
+        static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<uint32_t CTAS_PER_ROW>
+struct InterCTASync {
+
+    template<typename Params>
+    inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
+        : phase_counter_(0)
+        , b0_(params.barrier + bidm) // The barrier for this group of CTAs.
+        , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs.
+    {
+        // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
+    }
+
+    inline __device__ void spin_wait_(int *barrier, int step, int expected) {
+        asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
+        for( int found = -1; found != expected; ) {
+            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
+        }
+    }
+
+    inline __device__ void sync(){
+        // ALL THREADS MUST ENTER!
+
+        // We switch barrier every iteration.
+        int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
+        // We decrement every other iteration.
+        bool dec = phase_counter_ & 0x2;
+        int step = dec ? -1 : 1;
+        int expected = dec ? 0 : CTAS_PER_ROW;
+        // There are only 4 phases: up/down for b0/b1.
+        phase_counter_ = (phase_counter_ + 1) & 0x3;
+
+        if( threadIdx.x == 0 ) {
+            spin_wait_(barrier, step, expected);
+        }
+        // CTA waits for thread 0
+        __syncthreads();
+    }
+
+    int phase_counter_;
+    int * b0_;
+    int * b1_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
+
+    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
+    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
+    using Type = typename Base::Type;
+
+    enum { SMEM_BYTES = Base::SMEM_BYTES };
+
+    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
+    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
+
+    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
+    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
+
+    template<typename Params>
+    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
+        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) 
+        , inter_cta_(params, bidm, bidn)
+        , bidn_(bidn) // CTA id within the group.
+        , w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
+        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
+    {
+    }
+
+    template<typename Op>
+    inline __device__ T allreduce(T data, Op &op) {
+        data = Base::reduce(data, op);
+        // We switch workspace every iteration.
+        T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
+
+        // Warp leaders 0 hold the CTA-local results.
+        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
+            workspace[bidn_] = data;
+        }
+        inter_cta_.sync();
+        static_assert(CTAS_PER_ROW <= 32);
+        T total = Zeros<T>::get();
+        if(this->lane_ < CTAS_PER_ROW){
+            total = workspace[this->lane_];
+        }
+        total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
+
+        return total;
+    }
+
+    InterCTASync inter_cta_;
+
+    T *w0_;
+    T *w1_;
+    int bidn_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t WARPS_M>
+struct Reducer<T, 1, WARPS_M, 1> {
+
+    using Type = T;
+    enum { SMEM_BYTES = 0 };
+    enum { WORKSPACE_BYTES_PER_GROUP = 0 };
+
+    enum { THREADS_PER_WARP = 32 };
+
+    template<typename Params>
+    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) 
+        : warp_n_(warp_n)
+        , lane_(lane)
+    {
+    }
+
+    template<typename Op>
+    static inline __device__ T allreduce_(T data, Op &op) {
+        #pragma unroll
+        for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
+            data = op(data, warp_shuffle_xor(data, it));
+        }
+        return data;
+    }
+
+    template<typename Op>
+    inline __device__ T allreduce(T data, Op &op) {
+        return allreduce_(data, op);
+    }
+
+    template<typename Op>
+    inline __device__ T reduce(T data, Op &op){
+        // only lane 0 holds the result!
+        #pragma unroll
+        for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
+            data = op(data, warp_shuffle_down(data, it));
+        }  
+        return data;
+    }
+    int warp_n_;
+    int lane_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
+
+    using Base = Reducer<T, 1, WARPS_M, 1>;
+
+    using Type = T;
+
+    enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
+    enum { WORKSPACE_BYTES_PER_GROUP = 0 };
+
+    enum { THREADS_PER_WARP = 32 };
+
+    template<typename Params>
+    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) 
+        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) 
+        , use0_(true)
+    {
+        smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
+        smem1_ = smem0_ + WARPS_M * WARPS_N;
+    }
+
+    template<typename Op>
+    inline __device__ T allreduce(T data, Op & op) {
+        T * smem = use0_ ? smem0_ : smem1_;
+        use0_ = !use0_;
+        data = Base::reduce(data, op);
+        if( this->lane_ == 0 ) {
+            smem[this->warp_n_] = data;
+        }
+        __syncthreads();
+        T out = Zeros<T>::get();
+        #pragma unroll
+        for( int it = 0; it < WARPS_N; it++ ) {
+            out = op(out, smem[it]);
+        }
+        return out;
+    }
+
+    template<typename Op>
+    inline __device__ T reduce(T data, Op &op) {
+        T * smem = use0_ ? smem0_ : smem1_;
+        use0_ = !use0_;
+        // only intra-CTA group leader holds the result!
+        data = Base::reduce(data, op);
+        if( this->lane_ == 0 ) {
+            smem[this->warp_n_] = data;
+        }
+        __syncthreads();
+        T out = Zeros<T>::get();
+        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
+            #pragma unroll
+            for( int it = 0; it < WARPS_N; it++ ) {
+                out = op(out, smem[it]);
+            }
+        }
+        return out;
+    }
+
+    T * smem0_;
+    T * smem1_;
+    bool use0_;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+ 
+template<typename T>
+inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
+    //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
+    int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
+    
+    #pragma unroll
+    for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
+        // Exchange
+        T n_b = warp_shuffle_down(n_a, step);
+        T m_b = warp_shuffle_down(m_a, step);
+        T m2_b = warp_shuffle_down(m2_a, step);
+
+        // Update
+        const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
+        const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
+        const T delta = m_a - m_b;
+        const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
+        const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
+
+        n_a = n_ab;
+        m_a = m_ab;
+        m2_a = m2_ab;
+    }
+    // Intra-warp broadcast (only lane 0 has valid stats).
+    m_a = __shfl_sync(uint32_t(-1), m_a, 0);
+    m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Stats {
+    // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
+
+    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
+    using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
+    using stats_t = typename BlockStats::stats_t;
+
+    enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
+
+    template<typename Params>
+    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) 
+        : inter_cta_(params, bidm, bidn)
+        , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
+        , bidn_(bidn) // CTA id within the group.
+        , w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
+        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
+        , warp_n_(warp_n)
+        , lane_(lane)
+    {
+    }
+
+    template<uint32_t N>
+    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+        constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
+        // TODO rn is not really needed here..
+        constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
+        stats_t block_stats = block_stats_.compute(elts, block_rn);
+
+        stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
+
+        if( warp_n_ == 0 && lane_ == 0 ) {
+            workspace[bidn_] = block_stats;
+        }
+
+        // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
+        inter_cta_.sync();
+
+        T n = Zeros<T>::get();
+        T m = Zeros<T>::get();
+        T m2 = Zeros<T>::get();
+
+        // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
+        static_assert(CTAS_PER_ROW <= 32);
+
+        // Every warp does the final reduction locally. 
+        if( lane_ < CTAS_PER_ROW ) {
+            stats_t result = workspace[lane_];
+            n = ELTS_PER_ROW_PER_CTA;
+            m = layer_norm::Get<0>::of<stats_t, T>(result);
+            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
+        }
+
+        warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
+
+        return { m, m2 };
+    }
+
+    InterCTASync inter_cta_;
+    BlockStats block_stats_;
+
+    stats_t *w0_;
+    stats_t *w1_;
+    int bidn_;
+    int warp_n_;
+    int lane_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Stats<T, 1, WARPS_M, WARPS_N> {
+
+    using WarpStats = Stats<T, 1, WARPS_M, 1>;
+    using stats_t = typename WarpStats::stats_t;
+
+    enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
+
+    template<typename Params>
+    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) 
+        : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
+        , use0_(true)
+    {
+        smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
+        smem1_ = smem0_ + WARPS_M * WARPS_N;
+    }
+
+    template<uint32_t N>
+    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+        stats_t * smem = use0_ ? smem0_ : smem1_;
+        use0_ = !use0_;
+        // Compute warp local for all WARPS_N
+        constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
+        stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
+
+        //Each warp warp leader stores its stats
+        const auto warp_n = warp_stats_.reducer_.warp_n_;
+        const auto lane = warp_stats_.reducer_.lane_;
+        if( lane == 0 ) {
+            smem[warp_n] = warp_stats;
+        }
+        __syncthreads();
+
+        T n = Zeros<T>::get();
+        T m = Zeros<T>::get();
+        T m2 = Zeros<T>::get();
+
+        // Assume that there are less than 32 warps, such that we can finalize with a single warp
+        static_assert(WARPS_N <= 32);
+        if(lane < WARPS_N){
+            stats_t result = smem[lane];
+            n = N * THREADS_PER_WARP;
+            m = layer_norm::Get<0>::of<stats_t, T>(result);
+            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
+        }
+
+        warp_chan_upd_dynamic(m, m2, n, WARPS_N);
+
+        return { m, m2 };
+    }
+    WarpStats warp_stats_;
+    stats_t * smem0_;
+    stats_t * smem1_;
+    bool use0_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, uint32_t WARPS_M>
+struct Stats<T, 1, WARPS_M, 1> {
+
+    using stats_t = typename TypeToVec2<T>::Type;
+    // The simple Warp reducer.
+    using Reducer = Reducer<T, 1, WARPS_M, 1>;
+
+    enum { SMEM_BYTES = 0 };
+
+    template<typename Params>
+    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) 
+        : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
+    {
+    }
+
+    template<uint32_t N>
+    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+
+        auto sum = Sum<T>();
+
+        T m = Zeros<T>::get();
+        #pragma unroll
+        for( int it = 0; it < N; it++ ) {
+            m += elts[it];
+        }
+        m = reducer_.allreduce(m, sum) * rn;
+
+        T m2 = Zeros<T>::get();
+        #pragma unroll
+        for( int it = 0; it < N; it++ ) {
+            T diff = (elts[it] - m);
+            m2 += diff * diff;
+        }
+        m2 = reducer_.allreduce(m2, sum);
+
+        return {m, m2};
+    }
+
+    Reducer reducer_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace layer_norm

+ 143 - 0
csrc/layer_norm/setup.py

@@ -0,0 +1,143 @@
+# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from setuptools import setup, find_packages
+import subprocess
+
+import sys
+import warnings
+import os
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries.  "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors:  "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+if not torch.cuda.is_available():
+    # https://github.com/NVIDIA/apex/issues/486
+    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
+    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
+    print(
+        "\nWarning: Torch did not find available GPUs on this system.\n",
+        "If your intention is to cross-compile, this is not an error.\n"
+        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
+        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
+        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
+        "If you wish to cross-compile for a single specific architecture,\n"
+        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
+    )
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+        if int(bare_metal_major) == 11:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+            if int(bare_metal_minor) > 0:
+                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        else:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+
+print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+cmdclass = {}
+ext_modules = []
+
+# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+# See https://github.com/pytorch/pytorch/pull/70650
+generator_flag = []
+torch_dir = torch.__path__[0]
+if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+    generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+raise_if_cuda_home_none("--fast_layer_norm")
+# Check, if CUDA11 is installed for compute capability 8.0
+cc_flag = []
+# cc_flag.append("-gencode")
+# cc_flag.append("arch=compute_70,code=sm_70")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
+
+ext_modules.append(
+    CUDAExtension(
+        name="dropout_layer_norm",
+        sources=[
+            "ln_api.cpp",
+            "ln_fwd_cuda_kernel.cu",
+            "ln_bwd_semi_cuda_kernel.cu",
+        ],
+        extra_compile_args={
+            "cxx": ["-O3"] + generator_flag,
+            "nvcc": append_nvcc_threads(
+                [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                ]
+                + generator_flag
+                + cc_flag
+            ),
+        },
+        include_dirs=[this_dir],
+    )
+)
+
+setup(
+    name="dropout_layer_norm",
+    version="0.1",
+    description="Fused dropout + add + layer norm",
+    ext_modules=ext_modules,
+    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+)

+ 25 - 0
csrc/layer_norm/static_switch.h

@@ -0,0 +1,25 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param CONST_NAME - a name given for the constexpr bool variable.
+/// @param ...       - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+///     some_function<BoolConst>(...);
+/// });
+/// ```
+#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
+    [&] {                                                                            \
+        if (COND) {                                                                  \
+            constexpr bool CONST_NAME = true;                                        \
+            return __VA_ARGS__();                                                    \
+        } else {                                                                     \
+            constexpr bool CONST_NAME = false;                                       \
+            return __VA_ARGS__();                                                    \
+        }                                                                            \
+    }()

+ 6 - 0
csrc/xentropy/README.md

@@ -0,0 +1,6 @@
+This CUDA extension implements optimized cross-entropy loss, adapted from Apex's
+[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy).
+We make it work for bfloat16 and support in-place backward to save memory.
+```sh
+cd csrc/xentropy && pip install .
+```

+ 358 - 0
flash_attn/ops/fused_dense.py

@@ -0,0 +1,358 @@
+# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
+# We make it work with pytorch amp and with bfloat16.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+# import fused_dense_cuda  # from apex
+import fused_dense_lib as fused_dense_cuda
+# from src.ops.triton.triton_matmul import matmul_dgelu
+from flash_attn.ops.gelu_activation import gelu_bwd
+# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back
+
+
+# implements fused GEMM+bias in forward pass using mlp_cuda from apex
+class FusedDenseFuncTD(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight, bias):
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
+        x = x.contiguous()
+        weight = weight.contiguous()
+        bias = bias.contiguous()
+        ctx.save_for_backward(x, weight)
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
+        return output.reshape(*batch_shape, output.shape[-1])
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+        grad_output = grad_output.contiguous()
+        x, weight = ctx.saved_tensors
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        if ctx.needs_input_grad[0]:
+            grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
+                x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
+            )
+            grad_input = grad_input.reshape_as(x)
+        else:
+            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
+                x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
+            )
+            grad_input = None
+        # print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
+        return grad_input, grad_weight, grad_bias
+        # grad_input, grad_weight = None, None
+        # grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        # if ctx.needs_input_grad[0]:
+        #     grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
+        # if ctx.needs_input_grad[1]:
+        #     grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
+        # # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
+        # # will sum over the batch dimension to get grad_bias.
+        # return grad_input, grad_weight, grad_output
+
+
+fused_dense_function_td = FusedDenseFuncTD.apply
+
+
+class FusedDenseTD(nn.Linear):
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+
+    def forward(self, x):
+        if x.is_cuda and self.bias is not None:
+            return fused_dense_function_td(x, self.weight, self.bias)
+        else:
+            return F.linear(x, self.weight, self.bias)
+
+
+class FusedDenseResidualFunc(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight, bias):
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
+        x = x.contiguous()
+        x = x.contiguous()
+        weight = weight.contiguous()
+        bias = bias.contiguous()
+        ctx.save_for_backward(x, weight)
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
+        return output.reshape(*batch_shape, output.shape[-1]), x
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, grad_input):
+        grad_output = grad_output.contiguous()
+        grad_input = grad_input.contiguous()
+        x, weight = ctx.saved_tensors
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
+            x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
+            grad_input.reshape(batch_dim, n)
+        )
+        return grad_input.reshape_as(x), grad_weight, grad_bias
+
+
+fused_dense_residual_function = FusedDenseResidualFunc.apply
+
+
+class FusedDenseResidual(nn.Linear):
+    """Similar to FusedDense, but we return both the output and the input.
+    This is so that in the backward pass, we can combine the input gradient from the residual branch
+    with the input gradient from the matrix multiply, without having to do a separate addition.
+    """
+
+    def forward(self, x):
+        if x.is_cuda and self.bias is not None:
+            return fused_dense_residual_function(x, self.weight, self.bias)
+        else:
+            return F.linear(x, self.weight, self.bias), x
+
+
+class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
+        """checkpoint_lvl:
+        0: no recomputation in the bwd
+        1: recompute gelu_out in the bwd
+        2: recompute gelu_in and gelu_out in the bwd
+        """
+        assert -1 <= heuristic <= 4
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
+                                                 for a in [x, weight1, bias1, weight2, bias2]]
+        assert checkpoint_lvl in [0, 1, 2]
+        x = x.contiguous()
+        weight1 = weight1.contiguous()
+        bias1 = bias1.contiguous()
+        weight2 = weight2.contiguous()
+        bias2 = bias2.contiguous()
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
+        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
+        # )
+        if heuristic == -1:
+            gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+            output1 = F.gelu(gelu_in, approximate='tanh')
+            # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
+            # with torch.jit.fuser('fuser2'):
+            #     output1 = bias_gelu(gelu_in, bias1)
+        else:
+            save_gelu_in = checkpoint_lvl != 2
+            output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
+                                                                  bias1, save_gelu_in, heuristic)
+            if save_gelu_in:
+                gelu_in = rest[0]
+        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+        if checkpoint_lvl == 0:
+            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
+        elif checkpoint_lvl == 1:
+            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
+        elif checkpoint_lvl == 2:
+            ctx.save_for_backward(x, weight1, bias1, weight2)
+        return output2.reshape(*batch_shape, output2.shape[-1])
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+        grad_output = grad_output.contiguous()
+        checkpoint_lvl = ctx.checkpoint_lvl
+        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        if checkpoint_lvl == 0:
+            gelu_in, output1 = rest
+        elif checkpoint_lvl == 1:
+            gelu_in, = rest
+            output1 = F.gelu(gelu_in, approximate='tanh')
+        elif checkpoint_lvl == 2:
+            # bias1, = rest
+            if ctx.heuristic == -1:
+                gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+                output1 = F.gelu(gelu_in, approximate='tanh')
+            else:
+                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
+                                                                        weight1, bias1, True, ctx.heuristic)
+
+        if ctx.heuristic == -1:
+            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+            # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
+            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
+            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
+            grad_output1 = grad_output @ weight2
+            with torch.jit.fuser('fuser2'):
+                grad_gelu = gelu_bwd(grad_output1, gelu_in)
+            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
+                x.reshape(batch_dim, n), weight1, grad_gelu
+            )
+            # with torch.jit.fuser('fuser2'):
+            #     grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
+            # grad_input = grad_gelu @ weight1
+            # grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
+            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
+            #     x.reshape(batch_dim, n), weight1, grad_gelu
+            # )
+        else:
+            grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
+                x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
+                grad_output.reshape(batch_dim, grad_output.shape[-1]),
+                ctx.heuristic
+            )
+        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
+        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
+        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
+        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
+        #     x.reshape(batch_dim, n), weight1, grad_gelu
+        # )
+        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
+
+
+fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
+
+
+class FusedDenseGeluDenseTD(nn.Module):
+
+    def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
+                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+        """
+        checkpoint_lvl (increasing lvl means slower but more memory saving):
+            0: no recomputation in the bwd
+            1: recompute gelu_out in the bwd
+            2: recompute gelu_in and gelu_out in the bwd
+        heuristic:
+            -1: don't fuse gemm + gelu (separate kernel)
+            0..4: use this heuristic for the algo section in the fused gemm + gelu
+        """
+        assert checkpoint_lvl in [0, 1, 2]
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        if out_features is None:
+            out_features = in_features
+        assert bias == True, "DenseGeluDense module without bias is currently not supported"
+        self.checkpoint_lvl = checkpoint_lvl
+        self.heuristic = heuristic
+        self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
+        self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)
+
+    def forward(self, x):
+        return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
+                                                  self.fc2.weight, self.fc2.bias,
+                                                  self.checkpoint_lvl, self.heuristic)
+
+
+class FusedDenseResGeluDenseFunc(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
+        """checkpoint_lvl:
+        0: no recomputation in the bwd
+        1: recompute gelu_out in the bwd
+        2: recompute gelu_in and gelu_out in the bwd
+        """
+        assert -1 <= heuristic <= 4
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
+                                                 for a in [x, weight1, bias1, weight2, bias2]]
+        assert checkpoint_lvl in [0, 1, 2]
+        x = x.contiguous()
+        weight1 = weight1.contiguous()
+        bias1 = bias1.contiguous()
+        weight2 = weight2.contiguous()
+        bias2 = bias2.contiguous()
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
+        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
+        # )
+        # gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+        # output1 = F.gelu(gelu_in, approximate='tanh')
+        save_gelu_in = checkpoint_lvl != 2
+        output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
+                                                              bias1, save_gelu_in, heuristic)
+        if save_gelu_in:
+            gelu_in = rest[0]
+        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+        if checkpoint_lvl == 0:
+            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
+        elif checkpoint_lvl == 1:
+            ctx.save_for_backward(x, weight1, weight2, gelu_in)
+        elif checkpoint_lvl == 2:
+            ctx.save_for_backward(x, weight1, weight2, bias1)
+        return output2.reshape(*batch_shape, output2.shape[-1]), x
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, grad_input):
+        grad_output = grad_output.contiguous()
+        grad_input = grad_input.contiguous()
+        checkpoint_lvl = ctx.checkpoint_lvl
+        x, weight1, weight2, *rest = ctx.saved_tensors
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        if checkpoint_lvl == 0:
+            gelu_in, output1 = rest
+        elif checkpoint_lvl == 1:
+            gelu_in, = rest
+            output1 = F.gelu(gelu_in, approximate='tanh')
+        elif checkpoint_lvl == 2:
+            bias1, = rest
+            output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
+                                                                    weight1, bias1, True, ctx.heuristic)
+        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
+            x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
+            grad_output.reshape(batch_dim, grad_output.shape[-1]),
+            grad_input.reshape(batch_dim, n),
+            ctx.heuristic
+        )
+        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
+        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
+        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
+        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
+        #     x.reshape(batch_dim, n), weight1, grad_gelu,
+        #     grad_input.reshape(batch_dim, n)
+        # )
+        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
+
+
+fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply
+
+
+class FusedDenseResGeluDense(FusedDenseGeluDenseTD):
+
+    def forward(self, x):
+        return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
+                                                      self.fc2.weight, self.fc2.bias,
+                                                      self.checkpoint_lvl, False, self.heuristic)

+ 82 - 0
flash_attn/ops/gelu_activation.py

@@ -0,0 +1,82 @@
+# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
+import math
+
+import torch
+from torch import nn
+
+
+# 1/sqrt(2*pi)-> 0.3989423
+# 1/sqrt(2)   -> 0.70710678
+# sqrt(2/pi)  -> 0.79788456
+
+# this function is tanh approximation of gelu
+# actual gelu is:
+# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+@torch.jit.script
+def bias_gelu(y, bias):
+    x = bias + y
+    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
+
+# gradient of tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def bias_gelu_back(g, y, bias):
+    """Assume that y has shape (B, D) and bias has shape (D)
+    """
+    x = bias + y
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    grad_y = ff * g
+    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
+
+
+class GeLUFunction(torch.autograd.Function):
+    @staticmethod
+    # bias is an optional argument
+    def forward(ctx, input, bias):
+        ctx.save_for_backward(input, bias)
+        return bias_gelu(input, bias)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, bias = ctx.saved_tensors
+        tmp = bias_gelu_back(grad_output, input, bias)
+        return tmp, tmp
+
+
+bias_gelu_impl = GeLUFunction.apply
+
+# this function is tanh approximation of gelu
+# actual gelu is:
+# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+@torch.jit.script
+def gelu_fwd(x):
+    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
+
+# gradient of tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def gelu_bwd(g, x):
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    return (ff * g).to(dtype=x.dtype)
+
+
+class FastGeLUFunction(torch.autograd.Function):
+    @staticmethod
+    # bias is an optional argument
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        return gelu_fwd(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        tmp = gelu_bwd(grad_output, input)
+        return tmp
+
+fast_gelu_impl = FastGeLUFunction.apply

+ 167 - 0
flash_attn/ops/layer_norm.py

@@ -0,0 +1,167 @@
+# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
+import torch
+from torch.nn import init
+
+# from apex._autocast_utils import _cast_if_autocast_enabled
+import dropout_layer_norm
+
+
+def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
+                                    residual_in_fp32):
+    """ Assume that arguments are contiguous
+    """
+    hidden_size = gamma.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    rowscale = rowscale.view(-1) if rowscale is not None else None
+    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
+        x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
+    )
+    # dmask is None if dropout_p == 0.0
+    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
+
+
+def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
+                                     has_residual):
+    """ Assume that arguments are contiguous
+    """
+    # dmask is None if dropout_p == 0.0
+    hidden_size = gamma.numel()
+    xmat = x.view((-1, hidden_size))
+    dzmat = dz.view(xmat.shape)
+    rowscale = rowscale.view(-1) if rowscale is not None else None
+    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
+        dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+    )
+    # dx1mat is None if not has_residual
+    return dx0mat, dx1mat, dgamma, dbeta
+
+
+def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
+                                             dropout_p, has_residual):
+    """ Assume that arguments are contiguous
+    """
+    hidden_size = gamma.numel()
+    xmat = x.view((-1, hidden_size))
+    dzmat = dz.view(xmat.shape)
+    dxmat = dx.view(xmat.shape)
+    rowscale = rowscale.view(-1) if rowscale is not None else None
+    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
+        dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+    )
+    return dx0mat, dx1mat, dgamma, dbeta
+
+
+class DropoutAddLayerNormFN(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
+                return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        gamma = gamma.contiguous()
+        beta = beta.contiguous()
+        rowscale = rowscale.contiguous() if rowscale is not None else None
+        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
+            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
+        )
+        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
+        ctx.dropout_p = dropout_p
+        ctx.has_residual = x1 is not None
+        if not return_dmask:
+            return zmat.view(x0.shape)
+        else:
+            dmask = (dmask.view(x0.shape) if dropout_p > 0.
+                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask)
+            return zmat.view(x0.shape), dmask
+
+    @staticmethod
+    def backward(ctx, dz, *args):
+        # assert dz.is_contiguous()
+        dz = dz.contiguous()  # this happens!
+        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
+        dropout_p = ctx.dropout_p
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
+            dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+        )
+        dx0 = dx0mat.view(x.shape)
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
+
+
+class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
+                return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        gamma = gamma.contiguous()
+        beta = beta.contiguous()
+        rowscale = rowscale.contiguous() if rowscale is not None else None
+        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
+            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
+        )
+        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
+        ctx.dropout_p = dropout_p
+        ctx.has_residual = x1 is not None
+        if not return_dmask:
+            return zmat.view(x0.shape), xmat.view(x0.shape)
+        else:
+            dmask = (dmask.view(x0.shape) if dropout_p > 0.
+                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask)
+            return zmat.view(x0.shape), xmat.view(x0.shape), dmask
+
+    @staticmethod
+    def backward(ctx, dz, dx, *args):
+        # assert dz.is_contiguous()
+        dz = dz.contiguous()  # this happens!
+        dx = dx.contiguous()  # this happens!
+        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
+        dropout_p = ctx.dropout_p
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
+            dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+        )
+        dx0 = dx0mat.view(x.shape)
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
+
+
+def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
+                           prenorm=False, residual_in_fp32=False,
+                           return_dropout_mask=False):
+    """residual_in_fp32 only has an effect if x1 is None.
+    Otherwise residual dtype is x1.dtype.
+    """
+    args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
+            return_dropout_mask)
+    if not prenorm:
+        return DropoutAddLayerNormFN.apply(*args)
+    else:
+        return DropoutAddLayerNormPrenormFN.apply(*args)
+
+
+class DropoutAddLayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False,
+                 device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.prenorm = prenorm
+        self.p = p
+        self.epsilon = eps
+        self.residual_in_fp32 = residual_in_fp32
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init.ones_(self.weight)
+        init.zeros_(self.bias)
+
+    def forward(self, x0, x1=None):
+        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
+                                      self.p if self.training else 0.0, self.epsilon,
+                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)

+ 267 - 0
tests/ops/test_dropout_layer_norm.py

@@ -0,0 +1,267 @@
+import math
+
+import torch
+import torch.nn.functional as F
+import pytest
+
+from einops import rearrange
+
+from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
+
+
+is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
+
+@pytest.mark.parametrize('has_rowscale', [True, False])
+# @pytest.mark.parametrize('has_rowscale', [True])
+@pytest.mark.parametrize('has_residual', [True, False])
+# @pytest.mark.parametrize('has_residual', [False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+# @pytest.mark.parametrize('hidden_size', [768])
+def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
+                                     dropout_p, has_residual, has_rowscale):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    # Backward numerical error is high, and this case isn't used
+    if has_rowscale and not has_residual:
+        pytest.skip()
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_residual:
+        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_rowscale:
+        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
+        survival_rate = 0.87
+        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
+        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
+        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
+    else:
+        rowscale = None
+        x0_scaled_pt = x0_pt
+        x0_scaled_ref = x0_ref
+    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    torch.nn.init.normal_(model_pt.weight)
+    torch.nn.init.normal_(model_pt.bias)
+    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
+    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+        model_ref.weight.copy_(model_pt.weight)
+        model_ref.bias.copy_(model_pt.bias)
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
+                                        model.epsilon, rowscale=rowscale,
+                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
+    assert out.dtype == input_dtype
+    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
+    if has_residual:
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
+    else:
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
+    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
+    out_ref = model_ref(residual_ref)
+    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
+
+    g = torch.randn_like(out) / batch_size
+    out_pt.backward(g)
+    out.backward(g)
+    out_ref.backward(g)
+    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
+    if has_residual:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
+    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
+
+
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    dropout_p = 0.37
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 32
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+    x1 = x1_pt.detach().clone().requires_grad_()
+    x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
+    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+        model_ref.weight.copy_(model_pt.weight)
+        model_ref.bias.copy_(model_pt.bias)
+    model_pt.eval()
+    model.eval()
+    model_ref.eval()
+    out = model(x0, x1)
+    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
+    residual_ref = x0_ref + x1_ref
+    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
+    out_ref = model_ref(residual_ref)
+    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
+
+
+@pytest.mark.parametrize('has_rowscale', [True, False])
+@pytest.mark.parametrize('has_residual', [True, False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
+                                             dropout_p, has_residual, has_rowscale):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    # Backward numerical error is high, and this case isn't used
+    if has_rowscale and not has_residual:
+        pytest.skip()
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 2e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_residual:
+        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_rowscale:
+        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
+        survival_rate = 0.87
+        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
+        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
+        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
+    else:
+        rowscale = None
+        x0_scaled_pt = x0_pt
+        x0_scaled_ref = x0_ref
+    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
+    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
+                                dtype=weight_dtype)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+        model_ref.weight.copy_(model_pt.weight)
+        model_ref.bias.copy_(model_pt.bias)
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
+                                                  model.epsilon, rowscale=rowscale, prenorm=True,
+                                                  residual_in_fp32=residual_in_fp32,
+                                                  return_dropout_mask=True)
+    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
+    if has_residual:
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
+    else:
+        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
+    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
+    out_ref = model_ref(residual_ref)
+    assert out.dtype == input_dtype
+    assert residual.dtype == residual_dtype
+    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
+    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
+
+    g = torch.randn_like(out) / batch_size
+    (out_pt * F.sigmoid(residual_pt)).backward(g)
+    (out * F.sigmoid(residual)).backward(g)
+    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
+    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
+    if has_residual:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
+    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+
+
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 1e-4)
+    dropout_p = 0.37
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 32
+    seqlen = 512
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone().requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+    x1 = x1_pt.detach().clone().requires_grad_()
+    x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
+                                dtype=weight_dtype)
+    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+        model_ref.weight.copy_(model_pt.weight)
+        model_ref.bias.copy_(model_pt.bias)
+    model_pt.eval()
+    model.eval()
+    model_ref.eval()
+    out, residual = model(x0, x1)
+    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
+    residual_ref = x0_ref + x1_ref
+    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
+    out_ref = model_ref(residual_ref)
+    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
+    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

+ 154 - 0
tests/ops/test_fused_dense.py

@@ -0,0 +1,154 @@
+import math
+
+import torch
+import torch.nn.functional as F
+import pytest
+
+from einops import rearrange
+
+from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
+from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
+
+
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('out_features', [1024, 4096])
+@pytest.mark.parametrize('in_features', [1024, 4096])
+def test_fused_linear_bias(in_features, out_features, dtype):
+    device = 'cuda'
+    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
+    x = x_pt.detach().clone().requires_grad_()
+    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+    out_pt = model_pt(x_pt)
+    out = model(x)
+    # with torch.no_grad():
+    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
+
+    # If we don't divide by batch_size, the gradient gets a bit too large.
+    g = torch.randn_like(out) / 32
+    out_pt.backward(g)
+    out.backward(g)
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
+    # The error for d_weight and d_bias is quite a bit higher
+    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
+
+
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
+def test_fused_linear_bias_residual(in_features, out_features, dtype):
+    device = 'cuda'
+    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
+    x = x_pt.detach().clone().requires_grad_()
+    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+    out_pt = model_pt(x_pt) + F.gelu(x_pt)  # Just add some random function of the residual x_pt
+    out, x_copy = model(x)
+    out = out + F.gelu(x_copy)
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
+
+    # If we don't divide by batch_size, the gradient gets a bit too large.
+    g = torch.randn_like(out) / 32
+    out_pt.backward(g)
+    out.backward(g)
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
+    # The error for d_weight and d_bias is quite a bit higher
+    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
+
+
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('heuristic', [1, -1])
+@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+@pytest.mark.parametrize('out_features', [1024, 4096])
+@pytest.mark.parametrize('in_features', [1024, 4096])
+def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
+    device = 'cuda'
+    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
+    x = x_pt.detach().clone().requires_grad_()
+    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
+    model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
+                                  checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
+                                  device=device, dtype=dtype)
+    with torch.no_grad():
+        model.fc1.weight.copy_(model_pt_fc1.weight)
+        model.fc1.bias.copy_(model_pt_fc1.bias)
+        model.fc2.weight.copy_(model_pt_fc2.weight)
+        model.fc2.bias.copy_(model_pt_fc2.bias)
+    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
+    out = model(x)
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
+
+    # If we don't divide by batch_size, the gradient gets a bit too large.
+    g = torch.randn_like(out) / 32
+    out_pt.backward(g)
+    out.backward(g)
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
+    # The error for d_weight and d_bias is quite a bit higher
+    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
+    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
+
+
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+@pytest.mark.parametrize('out_features', [1024, 4096])
+@pytest.mark.parametrize('in_features', [1024, 4096])
+def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
+    device = 'cuda'
+    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
+    x = x_pt.detach().clone().requires_grad_()
+    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
+    model = FusedDenseResGeluDense(in_features, out_features, in_features,
+                                   checkpoint_lvl=checkpoint_lvl,
+                                   device=device, dtype=dtype)
+    with torch.no_grad():
+        model.fc1.weight.copy_(model_pt_fc1.weight)
+        model.fc1.bias.copy_(model_pt_fc1.bias)
+        model.fc2.weight.copy_(model_pt_fc2.weight)
+        model.fc2.bias.copy_(model_pt_fc2.bias)
+    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
+    out, x_copy = model(x)
+    out = out + F.gelu(x_copy)
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
+
+    # If we don't divide by batch_size, the gradient gets a bit too large.
+    g = torch.randn_like(out) / 32
+    out_pt.backward(g)
+    out.backward(g)
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
+    # The error for d_weight and d_bias is quite a bit higher
+    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
+    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)