Prechádzať zdrojové kódy

Simplify FusedDense

Tri Dao 2 rokov pred
rodič
commit
e68ebbe89a

+ 72 - 243
csrc/fused_dense_lib/fused_dense.cpp

@@ -6,6 +6,8 @@
 
 #include <stdio.h>
 
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
 // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                                \
@@ -24,14 +26,6 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");            \
   }
 
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-
-template <typename T>
-int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
-
-template <typename T>
-int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input,  bool residual, void *lt_workspace);
-
 template <typename T>
 int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
 
@@ -39,103 +33,34 @@ template <typename T>
 int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
 
 template <typename T>
-int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
-
-at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-
-  int out_features = weight.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_forward_cuda<scalar_t>(
-        input,
-        w_ptr,
-        bias,
-        in_features,
-        batch_size,
-        out_features,
-        out,
-        //out.data_ptr<scalar_t>(),
-       // reserved_space.data_ptr<scalar_t>(),
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_forward failed.")
-  });
-
-  return {out};
-}
-
-std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-
-  int out_features = weight.size(0);
+int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
 
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight = at::empty({out_features, in_features}, opts);
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
-#else
-  auto d_bias = at::empty({out_features}, opts);
-#endif
-  auto d_input = at::empty({batch_size, in_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(),
-        w_ptr,
-        d_output.data_ptr<scalar_t>(),
-        in_features,
-        batch_size,
-        out_features,
-        d_weight.data_ptr<scalar_t>(),
-        d_bias.data_ptr<scalar_t>(),
-        d_input.data_ptr<scalar_t>(),
-       // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/false,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_backward failed.")
-  });
-
-  return {d_input, d_weight, d_bias};
-}
-
-std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
+std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
 
+  int batch_size = input.size(0);
+  int in_features = input.size(1);
   int out_features = d_output.size(1);
 
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.dtype() == d_output.dtype());
+  TORCH_CHECK(input.is_cuda());
+  TORCH_CHECK(d_output.is_cuda());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(d_output.is_contiguous());
+  CHECK_SHAPE(input, batch_size, in_features);
+  CHECK_SHAPE(d_output, batch_size, out_features);
 
   // create output/workspace tensor
   auto opts = input.options();
   auto d_weight = at::empty({out_features, in_features}, opts);
+  at::Tensor d_bias;
+  if (has_d_bias) {
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
+    d_bias = d_output.view({-1, out_features}).sum(0, false);
 #else
-  auto d_bias = at::empty({out_features}, opts);
+    d_bias = at::empty({out_features}, opts);
 #endif
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  }
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
 
@@ -147,93 +72,59 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output)
         batch_size,
         out_features,
         d_weight.data_ptr<scalar_t>(),
-        d_bias.data_ptr<scalar_t>(),
-       // reserved_space.data_ptr<scalar_t>(),
+        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
         (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
+    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
   });
 
   return {d_weight, d_bias};
 }
 
-std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-
-  int out_features = weight.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight = at::empty({out_features, in_features}, opts);
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
-  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
-#else
-  auto d_bias = at::empty({out_features}, opts);
-#endif
-  CHECK_SHAPE(d_input, batch_size, in_features);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    auto result = linear_bias_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(),
-        w_ptr,
-        d_output.data_ptr<scalar_t>(),
-        in_features,
-        batch_size,
-        out_features,
-        d_weight.data_ptr<scalar_t>(),
-        d_bias.data_ptr<scalar_t>(),
-        d_input.data_ptr<scalar_t>(),
-       // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/true,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
-  });
-
-  return {d_input, d_weight, d_bias};
-}
-
-std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
+std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
+                                            c10::optional<at::Tensor> bias_,
                                             bool save_gelu_in, int heuristic) {
 
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-
+  int batch_size = input.size(0);
+  int in_features = input.size(1);
   int out_features = weight.size(0);
 
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.dtype() == weight.dtype());
+  TORCH_CHECK(input.is_cuda());
+  TORCH_CHECK(weight.is_cuda());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
+  CHECK_SHAPE(input, batch_size, in_features);
+  CHECK_SHAPE(weight, out_features, in_features);
+  if (bias_.has_value()) {
+    auto bias = bias_.value();
+    TORCH_CHECK(bias.dtype() == input.dtype());
+    TORCH_CHECK(bias.is_cuda());
+    TORCH_CHECK(bias.is_contiguous());
+    CHECK_SHAPE(bias, out_features);
+  }
 
   // create output/workspace tensor
   auto opts = input.options();
   auto output = at::empty({batch_size, out_features}, opts);
   at::Tensor gelu_in;
   if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
 
   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
-    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    scalar_t* b_ptr = bias.data_ptr<scalar_t>();
     auto result = linear_gelu_forward_cuda<scalar_t>(
         input.data_ptr<scalar_t>(),
-        w_ptr,
-        b_ptr,
+        weight.data_ptr<scalar_t>(),
+        bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
         in_features,
         batch_size,
         out_features,
         heuristic,
         output.data_ptr<scalar_t>(),
         save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
-       // reserved_space.data_ptr<scalar_t>(),
         (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
+    TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
   });
 
   std::vector<at::Tensor> result = {output};
@@ -241,116 +132,54 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
   return result;
 }
 
-std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
+std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
+  at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
+) {
 
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-
-  int hidden_features = weight1.size(0);
-  int out_features = weight2.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+  int batch_size = d_output.size(0);
+  int out_features = d_output.size(1);
+  int in_features = weight.size(1);
+
+  TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
+  TORCH_CHECK(weight.dtype() == d_output.dtype());
+  TORCH_CHECK(weight.dtype() == gelu_in.dtype());
+  TORCH_CHECK(weight.is_cuda());
+  TORCH_CHECK(d_output.is_cuda());
+  TORCH_CHECK(gelu_in.is_cuda());
+  TORCH_CHECK(weight.is_contiguous());
+  TORCH_CHECK(d_output.is_contiguous());
+  TORCH_CHECK(gelu_in.is_contiguous());
+  CHECK_SHAPE(weight, out_features, in_features);
+  CHECK_SHAPE(d_output, batch_size, out_features);
+  CHECK_SHAPE(gelu_in, batch_size, in_features);
 
   // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
-  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
-  auto d_bias1 = at::empty({hidden_features}, opts);
-  auto d_bias2 = at::empty({out_features}, opts);
+  auto opts = weight.options();
+  auto d_bias = at::empty({in_features}, opts);
   auto d_input = at::empty({batch_size, in_features}, opts);
-  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
-  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
-  auto lt_workspace = at::empty({1 << 22}, opts);
-
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
-    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(),
-        gelu_in.data_ptr<scalar_t>(),
-        output1.data_ptr<scalar_t>(),
-        weight1.data_ptr<scalar_t>(),
-        weight2.data_ptr<scalar_t>(),
-        d_output1.data_ptr<scalar_t>(),
-        d_output2.data_ptr<scalar_t>(),
-        in_features,
-        batch_size,
-        hidden_features,
-        out_features,
-        heuristic,
-        d_weight1.data_ptr<scalar_t>(),
-        d_weight2.data_ptr<scalar_t>(),
-        d_bias1.data_ptr<scalar_t>(),
-        d_bias2.data_ptr<scalar_t>(),
-        d_input.data_ptr<scalar_t>(),
-       // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/false,
-        (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
-  });
-
-  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
-}
-
-std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
-
-  auto batch_size = input.size(0);
-  auto in_features = input.size(1);
-
-  int hidden_features = weight1.size(0);
-  int out_features = weight2.size(0);
-
-  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
-
-  // create output/workspace tensor
-  auto opts = input.options();
-  auto d_weight1 = at::empty({hidden_features, in_features}, opts);
-  auto d_weight2 = at::empty({out_features, hidden_features}, opts);
-  auto d_bias1 = at::empty({hidden_features}, opts);
-  auto d_bias2 = at::empty({out_features}, opts);
-  CHECK_SHAPE(d_input, batch_size, in_features);
-  auto d_output1 = at::empty({batch_size, hidden_features}, opts);
-  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
   // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
   auto lt_workspace = at::empty({1 << 22}, opts);
 
-  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
-    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
-    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
-    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
-        input.data_ptr<scalar_t>(),
+  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
+    auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
+        weight.data_ptr<scalar_t>(),
+        d_output.data_ptr<scalar_t>(),
         gelu_in.data_ptr<scalar_t>(),
-        output1.data_ptr<scalar_t>(),
-        weight1.data_ptr<scalar_t>(),
-        weight2.data_ptr<scalar_t>(),
-        d_output1.data_ptr<scalar_t>(),
-        d_output2.data_ptr<scalar_t>(),
         in_features,
         batch_size,
-        hidden_features,
         out_features,
         heuristic,
-        d_weight1.data_ptr<scalar_t>(),
-        d_weight2.data_ptr<scalar_t>(),
-        d_bias1.data_ptr<scalar_t>(),
-        d_bias2.data_ptr<scalar_t>(),
         d_input.data_ptr<scalar_t>(),
-       // reserved_space.data_ptr<scalar_t>(),
-        /*residual=*/true,
+        d_bias.data_ptr<scalar_t>(),
         (void*) (lt_workspace.data_ptr<scalar_t>()));
-    TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
+    TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
   });
 
-  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
+  return {d_input, d_bias};
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
-  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
   m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
-  m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
   m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
-  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
-  m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
+  m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
 }

+ 18 - 431
csrc/fused_dense_lib/fused_dense_cuda.cu

@@ -94,226 +94,6 @@ cublasStatus_t gemm_bias(
 
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
 
-int gemm_bias_lt(
-    cublasLtHandle_t ltHandle,
-    cublasOperation_t transa,
-    cublasOperation_t transb,
-    int m,
-    int n,
-    int k,
-    const float *alpha, /* host pointer */
-    at::Half* A,
-    int lda,
-    at::Half* B,
-    int ldb,
-    const float *beta, /* host pointer */
-    at::Half* C,
-    int ldc,
-    void *workspace,
-    size_t workspaceSize,
-    cudaStream_t stream,
-    bool use_bias,
-    const void* bias) {
-  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
-
-  cublasLtMatmulDescOpaque_t operationDesc = {};
-  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
-  cublasLtMatmulPreferenceOpaque_t preference = {};
-
-  int returnedResults                             = 0;
-  cublasLtMatmulHeuristicResult_t heuristicResult = {};
-  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
-
-  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
-  // for details about defaults; here we just set the transforms for
-  // A and B.
-  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  if (use_bias) {
-    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
-    if (status != CUBLAS_STATUS_SUCCESS) {
-      goto CLEANUP;
-    }
-      epilogue = CUBLASLT_EPILOGUE_BIAS;
-  }
-
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
-  if (status != CUBLAS_STATUS_SUCCESS) {
-    goto CLEANUP;
-  }
-
-  // Create matrix descriptors. Not setting any extra attributes.
-  status = cublasLtMatrixLayoutInit(
-    &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(
-    &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // Create preference handle; In general, extra attributes can be
-  // used here to disable tensor ops or to make sure algo selected
-  // will work with badly aligned A, B, C. However, for simplicity
-  // here we assume A,B,C are always well aligned (e.g., directly
-  // come from cudaMalloc)
-  status = cublasLtMatmulPreferenceInit(&preference);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulPreferenceSetAttribute(
-    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // We just need the best available heuristic to try and run matmul.
-  // There is no guarantee that this will work. For example, if A is
-  // badly aligned, you can request more (e.g. 32) algos and try to
-  // run them one by one until something works.
-  status = cublasLtMatmulAlgoGetHeuristic(
-    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  if (returnedResults == 0) {
-    status = CUBLAS_STATUS_NOT_SUPPORTED;
-    goto CLEANUP;
-  }
-  status = cublasLtMatmul(ltHandle,
-                          &operationDesc,
-                          alpha,
-                          A,
-                          &Adesc,
-                          B,
-                          &Bdesc,
-                          beta,
-                          C,
-                          &Cdesc,
-                          C,
-                          &Cdesc,
-                          //&heuristicResult.algo,
-                          NULL,
-                          workspace,
-                          workspaceSize,
-                          stream);
-
-CLEANUP:
-  // Descriptors are no longer needed as all GPU work was already
-  // enqueued.
-  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
-}
-
-int gemm_bias_lt(
-    cublasLtHandle_t ltHandle,
-    cublasOperation_t transa,
-    cublasOperation_t transb,
-    int m,
-    int n,
-    int k,
-    const float *alpha, /* host pointer */
-    at::BFloat16* A,
-    int lda,
-    at::BFloat16* B,
-    int ldb,
-    const float *beta, /* host pointer */
-    at::BFloat16* C,
-    int ldc,
-    void *workspace,
-    size_t workspaceSize,
-    cudaStream_t stream,
-    bool use_bias,
-    const void* bias) {
-  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
-
-  cublasLtMatmulDescOpaque_t operationDesc = {};
-  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
-  cublasLtMatmulPreferenceOpaque_t preference = {};
-
-  int returnedResults                             = 0;
-  cublasLtMatmulHeuristicResult_t heuristicResult = {};
-  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
-
-  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
-  // for details about defaults; here we just set the transforms for
-  // A and B.
-  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  if (use_bias) {
-    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
-    if (status != CUBLAS_STATUS_SUCCESS) {
-      goto CLEANUP;
-    }
-      epilogue = CUBLASLT_EPILOGUE_BIAS;
-  }
-
-  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
-  if (status != CUBLAS_STATUS_SUCCESS) {
-    goto CLEANUP;
-  }
-
-  // Create matrix descriptors. Not setting any extra attributes.
-  status = cublasLtMatrixLayoutInit(
-    &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(
-    &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // Create preference handle; In general, extra attributes can be
-  // used here to disable tensor ops or to make sure algo selected
-  // will work with badly aligned A, B, C. However, for simplicity
-  // here we assume A,B,C are always well aligned (e.g., directly
-  // come from cudaMalloc)
-  status = cublasLtMatmulPreferenceInit(&preference);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-  status = cublasLtMatmulPreferenceSetAttribute(
-    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  // We just need the best available heuristic to try and run matmul.
-  // There is no guarantee that this will work. For example, if A is
-  // badly aligned, you can request more (e.g. 32) algos and try to
-  // run them one by one until something works.
-  status = cublasLtMatmulAlgoGetHeuristic(
-    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
-  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
-
-  if (returnedResults == 0) {
-    status = CUBLAS_STATUS_NOT_SUPPORTED;
-    goto CLEANUP;
-  }
-  status = cublasLtMatmul(ltHandle,
-                          &operationDesc,
-                          alpha,
-                          A,
-                          &Adesc,
-                          B,
-                          &Bdesc,
-                          beta,
-                          C,
-                          &Cdesc,
-                          C,
-                          &Cdesc,
-                          //&heuristicResult.algo,
-                          NULL,
-                          workspace,
-                          workspaceSize,
-                          stream);
-
-CLEANUP:
-  // Descriptors are no longer needed as all GPU work was already
-  // enqueued.
-  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
-}
-
 int gemm_bias_gelu_lt(
     cublasLtHandle_t ltHandle,
     cublasOperation_t transa,
@@ -332,7 +112,6 @@ int gemm_bias_gelu_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     int heuristic,
     const void* gelu_in,
     const void* bias) {
@@ -363,12 +142,14 @@ int gemm_bias_gelu_lt(
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
   }
 
-  if (use_bias) {
+  if (bias != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
     }
     epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
+  } else {
+    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
   }
 
   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
@@ -453,7 +234,6 @@ int gemm_bias_gelu_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     int heuristic,
     const void* gelu_in,
     const void* bias) {
@@ -484,12 +264,14 @@ int gemm_bias_gelu_lt(
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
   }
 
-  if (use_bias) {
+  if (bias != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
     }
     epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
+  } else {
+    epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
   }
 
   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
@@ -574,7 +356,6 @@ int gemm_bgradb_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     const void* bgrad) {
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -596,7 +377,7 @@ int gemm_bgradb_lt(
   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
 
-  if (use_bias) {
+  if (bgrad != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
@@ -684,7 +465,6 @@ int gemm_bgradb_lt(
     void *workspace,
     size_t workspaceSize,
     cudaStream_t stream,
-    bool use_bias,
     const void* bgrad) {
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -706,7 +486,7 @@ int gemm_bgradb_lt(
   status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
 
-  if (use_bias) {
+  if (bgrad != nullptr) {
     status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
     if (status != CUBLAS_STATUS_SUCCESS) {
       goto CLEANUP;
@@ -1008,132 +788,6 @@ CLEANUP:
 
 #endif
 
-template <typename T>
-int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
-    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-    // Get the stream from cublas handle to reuse for biasReLU kernel.
-    cudaStream_t stream;
-    cublasGetStream(handle, &stream);
-    const float alpha          = 1.0;
-    const float beta_zero       = 0.0;
-    const float beta_one       = 1.0;
-    int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-    status = gemm_bias_lt(
-    (cublasLtHandle_t)handle,
-    CUBLAS_OP_T,
-    CUBLAS_OP_N,
-    out_features,
-    batch_size,
-    in_features,
-    &alpha, /* host pointer */
-    weight,
-    in_features,
-    input.data_ptr<T>(),
-    in_features,
-    &beta_zero, /* host pointer */
-    output.data_ptr<T>(),
-    out_features,
-    lt_workspace,
-    1 << 22,
-    stream,
-    true,
-    static_cast<const void*>(bias.data_ptr<T>()));
-#endif
-    if (status != 0){
-        output.copy_(bias);
-        status = gemm_bias(
-          handle,
-          CUBLAS_OP_T,
-          CUBLAS_OP_N,
-          out_features,
-          batch_size,
-          in_features,
-          &alpha,
-          weight,
-          in_features,
-          input.data_ptr<T>(),
-          in_features,
-          &beta_one,
-          output.data_ptr<T>(),
-          out_features);
-    }
-    return status;
-}
-
-    
-template <typename T>
-int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace) {
-    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-    // Get the stream from cublas handle to reuse for biasReLU kernel.
-    cudaStream_t stream;
-    cublasGetStream(handle, &stream);
-    const float alpha          = 1.0;
-    const float beta_zero      = 0.0;
-    const float beta           = residual ? 1.0 : 0.0;
-    int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-    status = gemm_bgradb_lt(
-    (cublasLtHandle_t)handle,
-    CUBLAS_OP_N,
-    CUBLAS_OP_T,
-    in_features,
-    out_features,
-    batch_size,
-    &alpha, /* host pointer */
-    input,
-    in_features,
-    d_output,
-    out_features,
-    &beta_zero, /* host pointer */
-    d_weight,
-    in_features,
-    lt_workspace,
-    1 << 22,
-    stream,
-    true,
-    static_cast<const void*>(d_bias));
-#endif
-    
-
-    if (status != 0){
-    
-        status = gemm_bias(
-          handle,
-          CUBLAS_OP_N,
-          CUBLAS_OP_T,
-          in_features,
-          out_features,
-          batch_size,
-          &alpha,
-          input,
-          in_features,
-          d_output,
-          out_features,
-          &beta_zero,
-          d_weight,
-          in_features);
-    }
-    
-    status = gemm_bias(
-      handle,
-      CUBLAS_OP_N,
-      CUBLAS_OP_N,
-      in_features,
-      batch_size,
-      out_features,
-      &alpha,
-      weight,
-      in_features,
-      d_output,
-      out_features,
-      &beta,
-      d_input,
-      in_features);
-    return status;
-
-}
-
 template <typename T>
 int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) {
     cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
@@ -1162,13 +816,10 @@ int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_siz
     lt_workspace,
     1 << 22,
     stream,
-    true,
     static_cast<const void*>(d_bias));
 #endif
 
-
     if (status != 0){
-
         status = gemm_bias(
           handle,
           CUBLAS_OP_N,
@@ -1217,7 +868,6 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
     lt_workspace,
     1 << 22,
     stream,
-    true,
     heuristic,
     static_cast<const void*>(gelu_in),
     static_cast<const void*>(bias));
@@ -1228,109 +878,46 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
 }
 
 template <typename T>
-int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace) {
+int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace) {
     cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
     // Get the stream from cublas handle to reuse for biasReLU kernel.
     cudaStream_t stream;
     cublasGetStream(handle, &stream);
     const float alpha          = 1.0;
     const float beta_zero      = 0.0;
-    const float beta           = residual ? 1.0 : 0.0;
     int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
-//wgrad for first gemm
-    status = gemm_bgradb_lt(
-    (cublasLtHandle_t)handle,
-    CUBLAS_OP_N,
-    CUBLAS_OP_T,
-    hidden_features,
-    out_features,
-    batch_size,
-    &alpha, /* host pointer */
-    output1,
-    hidden_features,
-    d_output2,
-    out_features,
-    &beta_zero, /* host pointer */
-    d_weight2,
-    hidden_features,
-    lt_workspace,
-    1 << 22,
-    stream,
-    true,
-    static_cast<const void*>(d_bias2));
-//dgrad for second GEMM
     status = gemm_dgelu_bgradb_lt(
     (cublasLtHandle_t)handle,
     CUBLAS_OP_N,
     CUBLAS_OP_N,
-    hidden_features,
+    in_features,
     batch_size,
     out_features,
     &alpha, /* host pointer */
-    weight2,
-    hidden_features,
-    d_output2,
+    weight,
+    in_features,
+    d_output,
     out_features,
     &beta_zero, /* host pointer */
-    d_output1,
-    hidden_features,
+    d_input,
+    in_features,
     lt_workspace,
     1 << 22,
     stream,
     heuristic,
     static_cast<const void*>(gelu_in),
-    static_cast<const void*>(d_bias1));
-//wgrad for the first GEMM
-    status = gemm_bias(
-      handle,
-      CUBLAS_OP_N,
-      CUBLAS_OP_T,
-      in_features,
-      hidden_features,
-      batch_size,
-      &alpha,
-      input,
-      in_features,
-      d_output1,
-      hidden_features,
-      &beta_zero,
-      d_weight1,
-      in_features);
-
-//dgrad for the first GEMM
-    status = gemm_bias(
-      handle,
-      CUBLAS_OP_N,
-      CUBLAS_OP_N,
-      in_features,
-      batch_size,
-      hidden_features,
-      &alpha,
-      weight1,
-      in_features,
-      d_output1,
-      hidden_features,
-      &beta,
-      d_input,
-      in_features);
+    static_cast<const void*>(d_bias));
 #endif
     return status;
 
 }
 
-
-template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
-template int linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
-
-template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, bool residual, void *lt_workspace) ;
-template int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, bool residual, void *lt_workspace) ;
-
 template int linear_bias_wgrad_cuda<at::Half>(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace) ;
 template int linear_bias_wgrad_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace) ;
 
 template int linear_gelu_forward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace) ;
 template int linear_gelu_forward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace) ;
 
-template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, bool residual, void *lt_workspace);
-template int linear_gelu_linear_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, bool residual, void *lt_workspace);
+template int bias_gelu_linear_dgrad_bgrad_cuda<at::Half>(at::Half *weight, at::Half *d_output, at::Half *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace);
+template int bias_gelu_linear_dgrad_bgrad_cuda<at::BFloat16>(at::BFloat16 *weight, at::BFloat16 *d_output, at::BFloat16 *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace);

+ 4 - 4
flash_attn/layers/patch_embed.py

@@ -10,9 +10,9 @@ from torch.nn.modules.utils import _pair
 from einops import rearrange
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD = None
+    FusedDense = None
 
 
 class PatchEmbed(nn.Module):
@@ -37,10 +37,10 @@ class PatchEmbed(nn.Module):
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
 
-        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
         self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 

+ 10 - 8
flash_attn/models/bert.py

@@ -30,9 +30,9 @@ from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD = None
+    FusedDense = None
 
 try:
     from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
@@ -70,6 +70,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
                           activation=partial(F.gelu, approximate=approximate),
                           return_residual=return_residual)
     else:
+        if FusedDenseGeluDense is None:
+            raise ImportError('fused_dense is not installed')
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
         if isinstance(mlp_checkpoint_lvl, Sequence):
@@ -168,9 +170,9 @@ class BertPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
@@ -188,12 +190,12 @@ class BertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         self.transform_act_fn = nn.GELU(approximate=approximate)
@@ -215,9 +217,9 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
 
         self.transform = BertPredictionHeadTransform(config)
 

+ 2 - 0
flash_attn/models/gpt.py

@@ -61,6 +61,8 @@ def create_mlp_cls(config, layer_idx=None):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
         if fused_dense_gelu_dense:
+            if FusedDenseGeluDense is None:
+                raise ImportError('fused_dense is not installed')
             mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                               checkpoint_lvl=mlp_checkpoint_lvl)
         elif fused_dense_sqrelu_dense:

+ 7 - 6
flash_attn/modules/mha.py

@@ -21,9 +21,9 @@ except ImportError:
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
 
 try:
-    from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
+    from flash_attn.ops.fused_dense import FusedDense
 except ImportError:
-    FusedDenseTD, FusedDenseResidual = None, None
+    FusedDense = None
 
 try:
     from flash_attn.layers.rotary import RotaryEmbedding
@@ -270,7 +270,7 @@ class CrossAttention(nn.Module):
 
 
 class LinearResidual(nn.Linear):
-    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
+    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -311,10 +311,11 @@ class MHA(nn.Module):
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
 
-        if fused_bias_fc and FusedDenseTD is None:
+        if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
-        linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        linear_resid_cls = (LinearResidual if not fused_bias_fc
+                            else partial(FusedDense, return_residual=True))
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)

+ 2 - 44
flash_attn/modules/mlp.py

@@ -5,11 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 try:
-    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
-    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
+    from flash_attn.ops.fused_dense import FusedDenseGeluDense
 except ImportError:
-    fused_dense_gelu_dense_function_td = None
-    fused_dense_res_gelu_dense_function_td = None
+    FusedDenseGeluDense = None
 
 
 class Mlp(nn.Module):
@@ -30,43 +28,3 @@ class Mlp(nn.Module):
         y = self.activation(y)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
-
-
-class FusedDenseGeluDense(nn.Module):
-
-    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
-                 checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
-        """
-        checkpoint_lvl (increasing lvl means slower but more memory saving):
-            0: no recomputation in the bwd
-            1: recompute gelu_out in the bwd
-            2: recompute gelu_in and gelu_out in the bwd
-        heuristic:
-            -1: don't fuse gemm + gelu (separate kernel)
-            0..4: use this heuristic for the algo section in the fused gemm + gelu
-            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
-            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
-        return_residual: whether to return the input x along with the output. This is for
-            performance reason: for post-norm architecture, returning the input allows us
-            to fuse the backward of nn.Linear with the residual connection.
-        """
-        assert checkpoint_lvl in [0, 1, 2]
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        assert bias == True, "DenseGeluDense module without bias is currently not supported"
-        assert (fused_dense_gelu_dense_function_td is not None
-                and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
-        self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
-        self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
-
-    def forward(self, x):
-        assert x.is_cuda
-        fn = (fused_dense_gelu_dense_function_td if not self.return_residual
-              else fused_dense_res_gelu_dense_function_td)
-        return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
-                  self.checkpoint_lvl, self.heuristic)

+ 146 - 233
flash_attn/ops/fused_dense.py

@@ -1,9 +1,11 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
 # We make it work with pytorch amp and with bfloat16.
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
 # import fused_dense_cuda  # from apex
@@ -11,126 +13,84 @@ import fused_dense_lib as fused_dense_cuda
 from flash_attn.ops.gelu_activation import gelu_bwd
 
 
-# implements fused GEMM+bias in forward pass using mlp_cuda from apex
-class FusedDenseFuncTD(torch.autograd.Function):
+class FusedDenseFunc(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias):
+    def forward(ctx, x, weight, bias, return_residual=False):
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_gpu_dtype()
-            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
+            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
+            bias = bias.to(dtype=dtype) if bias is not None else None
+        ctx.return_residual = return_residual
         x = x.contiguous()
         weight = weight.contiguous()
-        bias = bias.contiguous()
         ctx.save_for_backward(x, weight)
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
-        return output.reshape(*batch_shape, output.shape[-1])
+        output = F.linear(x, weight, bias)
+        return output if not return_residual else (output, x)
 
     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            grad_input, = args
+            grad_input = grad_input.contiguous()
         x, weight = ctx.saved_tensors
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
-        if ctx.needs_input_grad[0]:
-            grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
-                x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[1]:
+            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
+                x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
             )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
             grad_input = grad_input.reshape_as(x)
         else:
-            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
-                x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
-            )
             grad_input = None
-        # print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
-        return grad_input, grad_weight, grad_bias
-        # grad_input, grad_weight = None, None
-        # grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        # if ctx.needs_input_grad[0]:
-        #     grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
-        # if ctx.needs_input_grad[1]:
-        #     grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
-        # # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
-        # # will sum over the batch dimension to get grad_bias.
-        # return grad_input, grad_weight, grad_output
+        return grad_input, grad_weight, grad_bias, None
 
 
-fused_dense_function_td = FusedDenseFuncTD.apply
+def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
+                     return_residual: bool = False):
+    batch_dim = x.shape[:-1].numel()
+    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
+                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
+        and dtype_eligible):
+        return FusedDenseFunc.apply(x, weight, bias, return_residual)
+    else:
+        out = F.linear(x, weight, bias)
+        return out if not return_residual else (out, x)
 
 
-class FusedDenseTD(nn.Linear):
+class FusedDense(nn.Linear):
 
     def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                 device=None, dtype=None) -> None:
+                 return_residual: bool = False, device=None, dtype=None) -> None:
         super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+        self.return_residual = return_residual
 
     def forward(self, x):
-        if x.is_cuda and self.bias is not None:
-            return fused_dense_function_td(x, self.weight, self.bias)
-        else:
-            return F.linear(x, self.weight, self.bias)
+        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
 
 
-class FusedDenseResidualFunc(torch.autograd.Function):
+class FusedDenseGeluDenseFunc(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias):
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
-        x = x.contiguous()
-        x = x.contiguous()
-        weight = weight.contiguous()
-        bias = bias.contiguous()
-        ctx.save_for_backward(x, weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
-        return output.reshape(*batch_shape, output.shape[-1]), x
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output, grad_input):
-        grad_output = grad_output.contiguous()
-        grad_input = grad_input.contiguous()
-        x, weight = ctx.saved_tensors
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
-            x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
-            grad_input.reshape(batch_dim, n)
-        )
-        return grad_input.reshape_as(x), grad_weight, grad_bias
-
-
-fused_dense_residual_function = FusedDenseResidualFunc.apply
-
-
-class FusedDenseResidual(nn.Linear):
-    """Similar to FusedDense, but we return both the output and the input.
-    This is so that in the backward pass, we can combine the input gradient from the residual branch
-    with the input gradient from the matrix multiply, without having to do a separate addition.
-    """
-
-    def forward(self, x):
-        if x.is_cuda and self.bias is not None:
-            return fused_dense_residual_function(x, self.weight, self.bias)
-        else:
-            return F.linear(x, self.weight, self.bias), x
-
-
-class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
+    def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
+                checkpoint_lvl=0, heuristic=0):
         """checkpoint_lvl:
         0: no recomputation in the bwd
         1: recompute gelu_out in the bwd
@@ -139,49 +99,53 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
         assert -1 <= heuristic <= 4
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
-                                                 for a in [x, weight1, bias1, weight2, bias2]]
+            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
+        if not save_gelu_in:
+            checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
+        ctx.return_residual = return_residual
         x = x.contiguous()
         weight1 = weight1.contiguous()
-        bias1 = bias1.contiguous()
+        bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
-        bias2 = bias2.contiguous()
+        bias2 = bias2.contiguous() if bias2 is not None else None
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
-        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
-        # )
         if heuristic == -1:
-            gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+            gelu_in = F.linear(x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
             # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
-            save_gelu_in = checkpoint_lvl != 2
             output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                   bias1, save_gelu_in, heuristic)
             if save_gelu_in:
                 gelu_in = rest[0]
-        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
+        output2 = F.linear(output1, weight2, bias2)
         ctx.checkpoint_lvl = checkpoint_lvl
         ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
-            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
+            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1:
-            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
+            ctx.save_for_backward(x, weight1, weight2, gelu_in)
         elif checkpoint_lvl == 2:
-            ctx.save_for_backward(x, weight1, bias1, weight2)
-        return output2.reshape(*batch_shape, output2.shape[-1])
+            ctx.save_for_backward(x, weight1, weight2, bias1)
+        output2 = output2.reshape(*batch_shape, output2.shape[-1])
+        return output2 if not return_residual else (output2, x)
 
     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         checkpoint_lvl = ctx.checkpoint_lvl
-        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
+        if ctx.return_residual:
+            grad_input, = args
+            grad_input = grad_input.contiguous()
+        x, weight1, weight2, *rest = ctx.saved_tensors
         batch_shape, n = x.shape[:-1], x.shape[-1]
         batch_dim = batch_shape.numel()
         if checkpoint_lvl == 0:
@@ -190,55 +154,88 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
             gelu_in, = rest
             output1 = F.gelu(gelu_in, approximate='tanh')
         elif checkpoint_lvl == 2:
-            # bias1, = rest
+            bias1, = rest
             if ctx.heuristic == -1:
-                gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+                gelu_in = F.linear(x, weight1, bias1)
                 output1 = F.gelu(gelu_in, approximate='tanh')
             else:
-                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
-                                                                        weight1, bias1, True, ctx.heuristic)
-
+                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
+                    x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic
+                )
+
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        output1 = output1.reshape(batch_dim, output1.shape[-1])
+        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
+        if ctx.needs_input_grad[3]:
+            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
+                output1, grad_output, ctx.needs_input_grad[4]
+            )
+        else:
+            grad_weight2 = None
+            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
         if ctx.heuristic == -1:
-            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-            # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
-            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
             # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
-            grad_output1 = grad_output @ weight2
+            grad_output1 = F.linear(grad_output, weight2.t())
             with torch.jit.fuser('fuser2'):
                 grad_gelu = gelu_bwd(grad_output1, gelu_in)
-            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
-                x.reshape(batch_dim, n), weight1, grad_gelu
-            )
-            # with torch.jit.fuser('fuser2'):
-            #     grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
-            # grad_input = grad_gelu @ weight1
-            # grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
-            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
-            #     x.reshape(batch_dim, n), weight1, grad_gelu
-            # )
+            if ctx.needs_input_grad[1]:
+                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
+                    x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
+                )
+            else:
+                grad_weight1 = None
+                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
         else:
-            grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
-                x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
-                grad_output.reshape(batch_dim, grad_output.shape[-1]),
-                ctx.heuristic
+            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
+            # just compute gelu grad
+            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
+                weight2, grad_output, gelu_in, ctx.heuristic
             )
-        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
-        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
-        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
-        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
-        #     x.reshape(batch_dim, n), weight1, grad_gelu
-        # )
-        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
-
-
-fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
+            if not ctx.needs_input_grad[2]:
+                grad_bias1 = None
+            if ctx.needs_input_grad[1]:
+                grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
+            else:
+                grad_weight1 = None
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_gelu, weight1.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
+            grad_input = grad_input.reshape_as(x)
+        else:
+            grad_input = None
+        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None
+
+
+def fused_dense_gelu_dense_func(
+    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
+    bias2: Optional[Tensor] = None,
+    save_gelu_in: bool = True, return_residual: bool = False,
+    checkpoint_lvl: int = 0, heuristic: int = 0
+):
+    batch_dim = x.shape[:-1].numel()
+    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
+                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
+        and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
+        and dtype_eligible):
+        return FusedDenseGeluDenseFunc.apply(
+            x, weight1, bias1, weight2, bias2,
+            save_gelu_in, return_residual, checkpoint_lvl, heuristic
+        )
+    else:
+        gelu_in = F.linear(x, weight1, bias1)
+        output1 = F.gelu(gelu_in, approximate='tanh')
+        output2 = F.linear(output1, weight2, bias2)
+        return output2 if not return_residual else (output2, x)
 
 
-class FusedDenseGeluDenseTD(nn.Module):
+class FusedDenseGeluDense(nn.Module):
 
-    def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
-                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
+                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
+                 device=None, dtype=None):
         """
         checkpoint_lvl (increasing lvl means slower but more memory saving):
             0: no recomputation in the bwd
@@ -247,110 +244,26 @@ class FusedDenseGeluDenseTD(nn.Module):
         heuristic:
             -1: don't fuse gemm + gelu (separate kernel)
             0..4: use this heuristic for the algo section in the fused gemm + gelu
+            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
+            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
+        return_residual: whether to return the input x along with the output. This is for
+            performance reason: for post-norm architecture, returning the input allows us
+            to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
             out_features = in_features
-        assert bias == True, "DenseGeluDense module without bias is currently not supported"
+        self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic
-        self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
-        self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
     def forward(self, x):
-        return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
-                                                  self.fc2.weight, self.fc2.bias,
-                                                  self.checkpoint_lvl, self.heuristic)
-
-
-class FusedDenseResGeluDenseFunc(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
-        """checkpoint_lvl:
-        0: no recomputation in the bwd
-        1: recompute gelu_out in the bwd
-        2: recompute gelu_in and gelu_out in the bwd
-        """
-        assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
-                                                 for a in [x, weight1, bias1, weight2, bias2]]
-        assert checkpoint_lvl in [0, 1, 2]
-        x = x.contiguous()
-        weight1 = weight1.contiguous()
-        bias1 = bias1.contiguous()
-        weight2 = weight2.contiguous()
-        bias2 = bias2.contiguous()
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
-        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
-        # )
-        # gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
-        # output1 = F.gelu(gelu_in, approximate='tanh')
-        save_gelu_in = checkpoint_lvl != 2
-        output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
-                                                              bias1, save_gelu_in, heuristic)
-        if save_gelu_in:
-            gelu_in = rest[0]
-        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
-        if checkpoint_lvl == 0:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
-        elif checkpoint_lvl == 1:
-            ctx.save_for_backward(x, weight1, weight2, gelu_in)
-        elif checkpoint_lvl == 2:
-            ctx.save_for_backward(x, weight1, weight2, bias1)
-        return output2.reshape(*batch_shape, output2.shape[-1]), x
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output, grad_input):
-        grad_output = grad_output.contiguous()
-        grad_input = grad_input.contiguous()
-        checkpoint_lvl = ctx.checkpoint_lvl
-        x, weight1, weight2, *rest = ctx.saved_tensors
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        if checkpoint_lvl == 0:
-            gelu_in, output1 = rest
-        elif checkpoint_lvl == 1:
-            gelu_in, = rest
-            output1 = F.gelu(gelu_in, approximate='tanh')
-        elif checkpoint_lvl == 2:
-            bias1, = rest
-            output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
-                                                                    weight1, bias1, True, ctx.heuristic)
-        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
-            x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
-            grad_output.reshape(batch_dim, grad_output.shape[-1]),
-            grad_input.reshape(batch_dim, n),
-            ctx.heuristic
+        return fused_dense_gelu_dense_func(
+            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
+            save_gelu_in=self.training, return_residual=self.return_residual,
+            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
         )
-        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
-        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
-        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
-        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
-        #     x.reshape(batch_dim, n), weight1, grad_gelu,
-        #     grad_input.reshape(batch_dim, n)
-        # )
-        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
-
-
-fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply
-
-
-class FusedDenseResGeluDense(FusedDenseGeluDenseTD):
-
-    def forward(self, x):
-        return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
-                                                      self.fc2.weight, self.fc2.bias,
-                                                      self.checkpoint_lvl, False, self.heuristic)

+ 57 - 94
tests/ops/test_fused_dense.py

@@ -6,29 +6,44 @@ import pytest
 
 from einops import rearrange
 
-from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
-from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
+from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
 
 
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize('has_bias', [True, False])
 @pytest.mark.parametrize('out_features', [1024, 4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_linear_bias(in_features, out_features, dtype):
+def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
     device = 'cuda'
     rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
+    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
+                       requires_grad=True)
     x = x_pt.detach().clone().requires_grad_()
-    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
+    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
+    model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
+                       device=device, dtype=dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
+        if has_bias:
+            model.bias.copy_(model_pt.bias)
     out_pt = model_pt(x_pt)
-    out = model(x)
+    if not return_residual:
+        out = model(x)
+    else:
+        out, x_copy = model(x)
+        x_copy = (x_copy[..., :out_features] if out_features < in_features
+                  else F.pad(x_copy, (0, out_features - in_features)))
+        x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
+                     else F.pad(x_pt, (0, out_features - in_features)))
+        # Just add some random function of the residual
+        out_pt = out_pt + F.gelu(x_pt_copy)
+        out = out + F.gelu(x_copy)
+
     # with torch.no_grad():
     #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
     assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
@@ -40,66 +55,52 @@ def test_fused_linear_bias(in_features, out_features, dtype):
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
-
-
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
-def test_fused_linear_bias_residual(in_features, out_features, dtype):
-    device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
-    x = x_pt.detach().clone().requires_grad_()
-    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
-    with torch.no_grad():
-        model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
-    out_pt = model_pt(x_pt) + F.gelu(x_pt)  # Just add some random function of the residual x_pt
-    out, x_copy = model(x)
-    out = out + F.gelu(x_copy)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
-
-    # If we don't divide by batch_size, the gradient gets a bit too large.
-    g = torch.randn_like(out) / 32
-    out_pt.backward(g)
-    out.backward(g)
-    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
-    # The error for d_weight and d_bias is quite a bit higher
-    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias:
+        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
 
 
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('heuristic', [1, -1])
+@pytest.mark.parametrize('heuristic', [0, -1])
 @pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize('has_bias2', [True, False])
+@pytest.mark.parametrize('has_bias1', [True, False])
 @pytest.mark.parametrize('out_features', [1024, 4096])
 @pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
+def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
+                                checkpoint_lvl, heuristic, dtype):
     device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
+    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
+                       requires_grad=True)
     x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
-    model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
-                                  checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
-                                  device=device, dtype=dtype)
+    model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device,
+                                   dtype=dtype)
+    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
+                                   dtype=dtype)
+    model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
+                                bias2=has_bias2, return_residual=return_residual,
+                                checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
+                                device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(model_pt_fc1.weight)
-        model.fc1.bias.copy_(model_pt_fc1.bias)
+        if has_bias1:
+            model.fc1.bias.copy_(model_pt_fc1.bias)
         model.fc2.weight.copy_(model_pt_fc2.weight)
-        model.fc2.bias.copy_(model_pt_fc2.bias)
+        if has_bias2:
+            model.fc2.bias.copy_(model_pt_fc2.bias)
     out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
-    out = model(x)
+    if not return_residual:
+        out = model(x)
+    else:
+        out, x_copy = model(x)
+        # Just add some random function of the residual
+        out_pt = out_pt + F.gelu(x_pt)
+        out = out + F.gelu(x_copy)
     assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
 
     # If we don't divide by batch_size, the gradient gets a bit too large.
@@ -109,46 +110,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuri
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
-    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
-
-
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
-    device = 'cuda'
-    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
-    x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
-    model = FusedDenseResGeluDense(in_features, out_features, in_features,
-                                   checkpoint_lvl=checkpoint_lvl,
-                                   device=device, dtype=dtype)
-    with torch.no_grad():
-        model.fc1.weight.copy_(model_pt_fc1.weight)
-        model.fc1.bias.copy_(model_pt_fc1.bias)
-        model.fc2.weight.copy_(model_pt_fc2.weight)
-        model.fc2.bias.copy_(model_pt_fc2.bias)
-    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
-    out, x_copy = model(x)
-    out = out + F.gelu(x_copy)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
-
-    # If we don't divide by batch_size, the gradient gets a bit too large.
-    g = torch.randn_like(out) / 32
-    out_pt.backward(g)
-    out.backward(g)
-    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
-    # The error for d_weight and d_bias is quite a bit higher
-    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias1:
+        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
     assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
-    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
+    if has_bias2:
+        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)