瀏覽代碼

[FusedDense] Allocate lt_workspace on input device

Tri Dao 1 年之前
父節點
當前提交
27f8f890df
共有 2 個文件被更改,包括 76 次插入52 次删除
  1. 28 6
      csrc/fused_dense_lib/fused_dense.cpp
  2. 48 46
      csrc/fused_dense_lib/fused_dense_cuda.cu

+ 28 - 6
csrc/fused_dense_lib/fused_dense.cpp

@@ -2,6 +2,7 @@
 // We make it work for bfloat16
 #include <torch/extension.h>
 #include <torch/torch.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <vector>
 
@@ -28,13 +29,13 @@
   }
 
 template <typename T>
-int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize);
 
 template <typename T>
-int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
 
 template <typename T>
-int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize);
 
 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
 
@@ -66,6 +67,11 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
     d_bias = at::empty({out_features}, opts);
 #endif
   }
+  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+  auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 
   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
     auto result = linear_bias_wgrad_cuda<scalar_t>(
@@ -75,7 +81,9 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
         batch_size,
         out_features,
         d_weight.data_ptr<scalar_t>(),
-        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
+        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
+        (void*) (lt_workspace.data_ptr()),
+        workspaceSize);
     TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
   });
 
@@ -117,6 +125,11 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
   // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
   if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
                                           is_gelu ? opts : opts.dtype(torch::kUInt8)); }
+  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+  auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 
   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
     auto result = linear_act_forward_cuda<scalar_t>(
@@ -129,7 +142,9 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
         is_gelu,
         heuristic,
         output.data_ptr<scalar_t>(),
-        save_pre_act ? pre_act.data_ptr() : nullptr);
+        save_pre_act ? pre_act.data_ptr() : nullptr,
+        (void*) (lt_workspace.data_ptr()),
+        workspaceSize);
     TORCH_CHECK(result == 0, "linear_act_forward failed.");
   });
 
@@ -168,6 +183,11 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
   auto opts = weight.options();
   auto d_bias = at::empty({in_features}, opts);
   auto d_input = at::empty({batch_size, in_features}, opts);
+  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+  auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 
   DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
     auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
@@ -180,7 +200,9 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
         is_gelu,
         heuristic,
         d_input.data_ptr<scalar_t>(),
-        d_bias.data_ptr<scalar_t>());
+        d_bias.data_ptr<scalar_t>(),
+        (void*) (lt_workspace.data_ptr()),
+        workspaceSize);
     TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
   });
 

+ 48 - 46
csrc/fused_dense_lib/fused_dense_cuda.cu

@@ -110,7 +110,9 @@ int gemm_bias_act_lt(
     int64_t ldc,
     void* pre_act,
     bool is_gelu,
-    int heuristic
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize
     ) {
   static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                 "gemm_bias_act_lt only supports fp16 and bf16");
@@ -120,14 +122,6 @@ int gemm_bias_act_lt(
 
   cublasLtHandle_t ltHandle =
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
-  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-  void* workspace = at::empty(
-    {static_cast<int64_t>(workspaceSize)},
-    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -228,7 +222,7 @@ int gemm_bias_act_lt(
                           // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
                           &heuristicResult[heuristic].algo,
                           // NULL,
-                          workspace,
+                          lt_workspace,
                           workspaceSize,
                           at::cuda::getCurrentCUDAStream());
 
@@ -254,7 +248,9 @@ template int gemm_bias_act_lt(
     int64_t ldc,
     void* pre_act,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template int gemm_bias_act_lt(
     cublasOperation_t transa,
@@ -272,7 +268,9 @@ template int gemm_bias_act_lt(
     int64_t ldc,
     void* pre_act,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template <typename Dtype>
 int gemm_bgradb_lt(
@@ -288,7 +286,9 @@ int gemm_bgradb_lt(
     int64_t ldb,
     Dtype* C,
     int64_t ldc,
-    Dtype* bgrad) {
+    Dtype* bgrad,
+    void *lt_workspace,
+    size_t workspaceSize) {
   static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                 "gemm_bgradb_lt only supports fp16 and bf16");
   float beta = 0.0;
@@ -296,13 +296,6 @@ int gemm_bgradb_lt(
 
   cublasLtHandle_t ltHandle =
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-  void* workspace = at::empty(
-    {static_cast<int64_t>(workspaceSize)},
-    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -384,7 +377,7 @@ int gemm_bgradb_lt(
                           &Cdesc,
                           //&heuristicResult.algo,
                           NULL,
-                          workspace,
+                          lt_workspace,
                           workspaceSize,
                           at::cuda::getCurrentCUDAStream());
 
@@ -408,7 +401,9 @@ template int gemm_bgradb_lt(
     int64_t ldb,
     at::Half* C,
     int64_t ldc,
-    at::Half* bgrad);
+    at::Half* bgrad,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template int gemm_bgradb_lt(
     cublasOperation_t transa,
@@ -423,7 +418,9 @@ template int gemm_bgradb_lt(
     int64_t ldb,
     at::BFloat16* C,
     int64_t ldc,
-    at::BFloat16* bgrad);
+    at::BFloat16* bgrad,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template <typename Dtype>
 int gemm_dact_bgradb_lt(
@@ -442,7 +439,9 @@ int gemm_dact_bgradb_lt(
     int64_t ldc,
     Dtype* bgrad,
     bool is_gelu,
-    int heuristic) {
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize) {
   static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                 "gemm_dact_bgradb_lt only supports fp16 and bf16");
   float beta = 0.0;
@@ -450,13 +449,6 @@ int gemm_dact_bgradb_lt(
 
   cublasLtHandle_t ltHandle =
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-  void* workspace = at::empty(
-    {static_cast<int64_t>(workspaceSize)},
-    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -542,7 +534,7 @@ int gemm_dact_bgradb_lt(
                           //&heuristicResult.algo,
                           &heuristicResult[heuristic].algo,
                           // NULL,
-                          workspace,
+                          lt_workspace,
                           workspaceSize,
                           at::cuda::getCurrentCUDAStream());
 
@@ -568,7 +560,9 @@ template int gemm_dact_bgradb_lt(
     int64_t ldc,
     at::Half* bgrad,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template int gemm_dact_bgradb_lt(
     cublasOperation_t transa,
@@ -586,12 +580,14 @@ template int gemm_dact_bgradb_lt(
     int64_t ldc,
     at::BFloat16* bgrad,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 #endif
 
 template <typename T>
-int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) {
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
     const float alpha          = 1.0;
     const float beta_zero      = 0.0;
     int status = 1;
@@ -610,7 +606,9 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
     out_features,
     d_weight,
     in_features,
-    d_bias);
+    d_bias,
+    lt_workspace,
+    workspaceSize);
 #endif
 
     if (status != 0){
@@ -652,7 +650,7 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
 }
 
 template <typename T>
-int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) {
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
     int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
     status = gemm_bias_act_lt(
@@ -671,7 +669,9 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
     out_features,
     pre_act,
     is_gelu,
-    heuristic);
+    heuristic,
+    lt_workspace,
+    workspaceSize);
     return status;
 #else
     return 1;
@@ -679,7 +679,7 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
 }
 
 template <typename T>
-int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) {
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
     const float alpha          = 1.0;
     int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
@@ -699,17 +699,19 @@ int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const v
     in_features,
     d_bias,
     is_gelu,
-    heuristic);
+    heuristic,
+    lt_workspace,
+    workspaceSize);
 #endif
     return status;
 
 }
 
-template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias);
-template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias);
+template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
+template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
 
-template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act);
-template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act);
+template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
+template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
 
-template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias);
-template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias);
+template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
+template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);