// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/extension.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#include <cublasLt.h>
#endif

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    const float* alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    const float* beta,
    at::Half* C, int64_t ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb,
      beta, C, CUDA_R_16F, ldc,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    const float* alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    const float* beta,
    at::BFloat16* C, int64_t ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb,
      beta, C, CUDA_R_16BF, ldc,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
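// Illustrative sketch (not part of the adapted apex code): one way the fp16 gemm_bias
// wrapper above could be exercised on raw device pointers. cuBLAS uses column-major
// operands, so this assumes column-major buffers with leading dimensions equal to their
// row counts. The helper name `example_gemm_bias_fp16` is hypothetical.
static inline void example_gemm_bias_fp16(const at::Half* A, const at::Half* B, at::Half* C,
                                          int64_t m, int64_t n, int64_t k) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // C (m x n) = A (m x k) * B (k x n), fp16 storage with fp32 accumulation.
  cublasStatus_t status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                    m, n, k, &alpha, A, m, B, k, &beta, C, m);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "gemm_bias (fp16) failed");
}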
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600

template <typename Dtype>
int gemm_bias_act_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const Dtype* A, int64_t lda,
    const Dtype* B, int64_t ldb,
    const Dtype* bias,
    Dtype* C, int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bias_act_lt only supports fp16 and bf16");
  bool save_pre_act = pre_act != nullptr;
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
  // constexpr int requestedAlgoCount = 1;
  // cublasLtMatmulHeuristicResult_t heuristicResult = {};

  cublasLtEpilogue_t epilogue = is_gelu
      ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
      : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (save_pre_act) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                            &pre_act, sizeof(pre_act));
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                            &ldc, sizeof(ldc));
  }

  if (bias != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    epilogue = is_gelu
        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)
        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS);
  } else {
    epilogue = is_gelu
        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount,
      heuristicResult, &returnedResults);
  // ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A, &Adesc,
                          B, &Bdesc,
                          &beta,
                          C, &Cdesc,
                          C, &Cdesc,
                          // &heuristicResult.algo,
                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
                          &heuristicResult[heuristic].algo,
                          // NULL,
                          lt_workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

template int gemm_bias_act_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    const at::Half* bias,
    at::Half* C, int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

template int gemm_bias_act_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    const at::BFloat16* bias,
    at::BFloat16* C, int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

template <typename Dtype>
int gemm_bgradb_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const Dtype* A, int64_t lda,
    const Dtype* B, int64_t ldb,
    Dtype* C, int64_t ldc,
    Dtype* bgrad,
    void *lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};

  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (bgrad != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    epilogue = CUBLASLT_EPILOGUE_BGRADB;
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A, &Adesc,
                          B, &Bdesc,
                          &beta,
                          C, &Cdesc,
                          C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          lt_workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

template int gemm_bgradb_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    at::Half* C, int64_t ldc,
    at::Half* bgrad,
    void *lt_workspace,
    size_t workspaceSize);

template int gemm_bgradb_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    at::BFloat16* C, int64_t ldc,
    at::BFloat16* bgrad,
    void *lt_workspace,
    size_t workspaceSize);

template <typename Dtype>
int gemm_dact_bgradb_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const Dtype* A, int64_t lda,
    const Dtype* B, int64_t ldb,
    const void* pre_act,
    Dtype* C, int64_t ldc,
    Dtype* bgrad,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_dact_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;

  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());

  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};

  cublasLtEpilogue_t epilogue = is_gelu ?
      CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                          &pre_act, sizeof(pre_act));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount,
      heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A, &Adesc,
                          B, &Bdesc,
                          &beta,
                          C, &Cdesc,
                          C, &Cdesc,
                          //&heuristicResult.algo,
                          &heuristicResult[heuristic].algo,
                          // NULL,
                          lt_workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ?
      0 : 1;
}

template int gemm_dact_bgradb_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half* A, int64_t lda,
    const at::Half* B, int64_t ldb,
    const void* pre_act,
    at::Half* C, int64_t ldc,
    at::Half* bgrad,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

template int gemm_dact_bgradb_lt(
    cublasOperation_t transa, cublasOperation_t transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16* A, int64_t lda,
    const at::BFloat16* B, int64_t ldb,
    const void* pre_act,
    at::BFloat16* C, int64_t ldc,
    at::BFloat16* bgrad,
    bool is_gelu,
    int heuristic,
    void *lt_workspace,
    size_t workspaceSize);

#endif

template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size,
                           int64_t out_features, T *d_weight, T *d_bias,
                           void *lt_workspace, size_t workspaceSize) {
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bgradb_lt(
      // (cublasLtHandle_t)handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_features,
      out_features,
      batch_size,
      alpha,
      input,
      in_features,
      d_output,
      out_features,
      d_weight,
      in_features,
      d_bias,
      lt_workspace,
      workspaceSize);
#endif

  if (status != 0) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    status = gemm_bias(
        handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size,
        &alpha, input, in_features, d_output, out_features, &beta_zero, d_weight, in_features);
    // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error
    // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341
    // at::cuda::blas::gemm(
    //     'N', 'T', in_features, out_features, batch_size,
    //     alpha, input, in_features, d_output, out_features,
    //     beta_zero, d_weight, in_features);
  }
  return status;
}

template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features,
                            int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic,
                            T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bias_act_lt(
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      out_features,
      batch_size,
      in_features,
      /*alpha=*/1.0,
      weight,
      in_features,
      input,
      in_features,
      bias,
      output,
      out_features,
      pre_act,
      is_gelu,
      heuristic,
      lt_workspace,
      workspaceSize);
  return status;
#else
  return 1;
#endif
}

template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act,
                                     int64_t in_features, int64_t batch_size, int64_t out_features,
                                     bool is_gelu, int heuristic, T *d_input, T *d_bias,
                                     void *lt_workspace, size_t workspaceSize) {
  const float alpha = 1.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_dact_bgradb_lt(
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      in_features,
      batch_size,
      out_features,
      alpha,
      weight,
      in_features,
      d_output,
      out_features,
      pre_act,
      d_input,
      in_features,
      d_bias,
      is_gelu,
      heuristic,
      lt_workspace,
      workspaceSize);
#endif
  return status;
}

template int linear_bias_wgrad_cuda(const at::Half *input, const at::Half *d_output, int64_t in_features,
                                    int64_t batch_size, int64_t out_features, at::Half *d_weight,
                                    at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
template int linear_bias_wgrad_cuda(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features,
                                    int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight,
                                    at::BFloat16 *d_bias, void
                                    *lt_workspace, size_t workspaceSize);

template int linear_act_forward_cuda(const at::Half *input, const at::Half *weight, const at::Half *bias,
                                     int64_t in_features, int64_t batch_size, int64_t out_features,
                                     bool is_gelu, int heuristic, at::Half *output, void *pre_act,
                                     void *lt_workspace, size_t workspaceSize);
template int linear_act_forward_cuda(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias,
                                     int64_t in_features, int64_t batch_size, int64_t out_features,
                                     bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act,
                                     void *lt_workspace, size_t workspaceSize);

template int bias_act_linear_dgrad_bgrad_cuda(const at::Half *weight, const at::Half *d_output, const void *pre_act,
                                              int64_t in_features, int64_t batch_size, int64_t out_features,
                                              bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias,
                                              void *lt_workspace, size_t workspaceSize);
template int bias_act_linear_dgrad_bgrad_cuda(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act,
                                              int64_t in_features, int64_t batch_size, int64_t out_features,
                                              bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias,
                                              void *lt_workspace, size_t workspaceSize);
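
// Illustrative sketch (not part of the adapted apex code): a minimal host-side wrapper
// showing how the instantiations above might be driven from a PyTorch extension. The real
// repository provides its own bindings; the function name `example_linear_gelu_forward`,
// the 1 MB workspace size, and the shape/contiguity assumptions here are illustrative only.
std::vector<at::Tensor> example_linear_gelu_forward(at::Tensor input, at::Tensor weight,
                                                    at::Tensor bias, int heuristic) {
  // Assumes contiguous 2-D tensors: input [batch_size, in_features], weight [out_features, in_features].
  const int64_t batch_size = input.size(0);
  const int64_t in_features = input.size(1);
  const int64_t out_features = weight.size(0);
  auto opts = input.options();
  auto output = at::empty({batch_size, out_features}, opts);
  auto pre_act = at::empty({batch_size, out_features}, opts);
  // cublasLt workspace; 1 MB is an arbitrary illustrative size.
  auto lt_workspace = at::empty({1 << 20}, opts.dtype(at::kByte));
  int status = 1;
  if (input.scalar_type() == at::kHalf) {
    status = linear_act_forward_cuda<at::Half>(
        input.data_ptr<at::Half>(), weight.data_ptr<at::Half>(), bias.data_ptr<at::Half>(),
        in_features, batch_size, out_features, /*is_gelu=*/true, heuristic,
        output.data_ptr<at::Half>(), pre_act.data_ptr(),
        lt_workspace.data_ptr(), lt_workspace.numel());
  } else if (input.scalar_type() == at::kBFloat16) {
    status = linear_act_forward_cuda<at::BFloat16>(
        input.data_ptr<at::BFloat16>(), weight.data_ptr<at::BFloat16>(), bias.data_ptr<at::BFloat16>(),
        in_features, batch_size, out_features, /*is_gelu=*/true, heuristic,
        output.data_ptr<at::BFloat16>(), pre_act.data_ptr(),
        lt_workspace.data_ptr(), lt_workspace.numel());
  }
  TORCH_CHECK(status == 0, "linear_act_forward_cuda failed");
  return {output, pre_act};
}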