|
@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
|
|
|
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
|
|
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
|
|
// setting this to 1M.
|
|
|
- size_t workspaceSize = 1024 * 1024;
|
|
|
+ // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
|
|
+ // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
|
|
|
+ size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
|
|
void* workspace = at::empty(
|
|
|
{static_cast<int64_t>(workspaceSize)},
|
|
|
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
|
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
|
|
|
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
|
|
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
|
|
// setting this to 1M.
|
|
|
- size_t workspaceSize = 1024 * 1024;
|
|
|
+ // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
|
|
+ size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
|
|
void* workspace = at::empty(
|
|
|
{static_cast<int64_t>(workspaceSize)},
|
|
|
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
|
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
|
|
|
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
|
|
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
|
|
// setting this to 1M.
|
|
|
- size_t workspaceSize = 1024 * 1024;
|
|
|
+ // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
|
|
+ size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
|
|
void* workspace = at::empty(
|
|
|
{static_cast<int64_t>(workspaceSize)},
|
|
|
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|