// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#include <cublasLt.h>
#endif
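
// This translation unit provides the GEMM building blocks for fused dense
// layers: cuBLASLt matmuls with fused bias/activation epilogues (guarded by
// CUBLAS_VERSION >= 11600), plus plain cublasGemmEx wrappers that serve as a
// fallback when the Lt path is unavailable or fails.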
// FP16 Tensor-core wrapper around cublasGemmEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const float* beta,
    at::Half* C,
    int64_t ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_16F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
// BF16 Tensor-core wrapper around cublasGemmEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const float* alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const float* beta,
    at::BFloat16* C,
    int64_t ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_16BF,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
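
// A minimal usage sketch (illustrative only, not part of the original file):
// a plain fp16 GEMM C = A * B with fp32 accumulation, where A is m x k, B is
// k x n, and C is m x n, all column-major as cuBLAS expects.
static inline cublasStatus_t example_gemm_fp16(
    const at::Half* A, const at::Half* B, at::Half* C,
    int64_t m, int64_t n, int64_t k) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // With CUBLAS_OP_N for both operands, the leading dimensions are the row
  // counts of the column-major matrices: lda = m, ldb = k, ldc = m.
  return gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                   &alpha, A, m, B, k, &beta, C, m);
}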
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
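// Fused GEMM + bias + activation:
//   C = act(alpha * op(A) * op(B) + bias)
// with act = GELU or ReLU, implemented via cuBLASLt epilogues. If pre_act is
// non-null, the epilogue's auxiliary output (the pre-GELU values, or the
// ReLU sign information) is stashed there for the backward pass. `heuristic`
// indexes into the algorithms returned by cublasLtMatmulAlgoGetHeuristic.
// Returns 0 on success, 1 on failure.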
template <typename Dtype>
int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const Dtype* bias,
    Dtype* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void* lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bias_act_lt only supports fp16 and bf16");
  bool save_pre_act = pre_act != nullptr;
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {};
  // Default epilogue without bias; upgraded to the *_BIAS variants below
  // when a bias pointer is provided.
  cublasLtEpilogue_t epilogue = is_gelu
      ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
      : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);
  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for
  // details about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (save_pre_act) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  }
  if (bias != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = is_gelu
        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)
        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS);
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create preference handle; in general, extra attributes can be used here
  // to disable tensor ops or to make sure the selected algo works with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are
  // always well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // We just need the best available heuristic to try and run the matmul.
  // There is no guarantee that this will work. For example, if A is badly
  // aligned, you can request more (e.g. 32) algos and try to run them one by
  // one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
                          &heuristicResult[heuristic].algo,
                          lt_workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
template int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const at::Half* bias,
    at::Half* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void* lt_workspace,
    size_t workspaceSize);

template int gemm_bias_act_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const at::BFloat16* bias,
    at::BFloat16* C,
    int64_t ldc,
    void* pre_act,
    bool is_gelu,
    int heuristic,
    void* lt_workspace,
    size_t workspaceSize);
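
// GEMM with an optional fused bias-gradient reduction:
//   C = alpha * op(A) * op(B)
// and, when bgrad is non-null, bgrad additionally receives the B operand
// reduced over its batch dimension (CUBLASLT_EPILOGUE_BGRADB). Used below by
// linear_bias_wgrad_cuda to produce d_weight and d_bias in one call.
// Returns 0 on success, 1 on failure.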
template <typename Dtype>
int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad,
    void* lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for
  // details about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (bgrad != nullptr) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_BGRADB;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create preference handle; in general, extra attributes can be used here
  // to disable tensor ops or to make sure the selected algo works with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are
  // always well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // We just need the best available heuristic to try and run the matmul.
  // There is no guarantee that this will work. For example, if A is badly
  // aligned, you can request more (e.g. 32) algos and try to run them one by
  // one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          // Passing NULL lets cublasLtMatmul pick the algorithm itself;
                          // the heuristic query above only validates that one exists.
                          NULL,
                          lt_workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
template int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    at::Half* C,
    int64_t ldc,
    at::Half* bgrad,
    void* lt_workspace,
    size_t workspaceSize);

template int gemm_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    at::BFloat16* C,
    int64_t ldc,
    at::BFloat16* bgrad,
    void* lt_workspace,
    size_t workspaceSize);
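
// Fused backward GEMM: computes C = alpha * op(A) * op(B), applies the
// GELU/ReLU derivative to the result using the auxiliary data stashed in
// pre_act during the forward pass (CUBLASLT_EPILOGUE_DGELU_BGRAD /
// CUBLASLT_EPILOGUE_DRELU_BGRAD), and writes the resulting bias gradient
// into bgrad. Returns 0 on success, 1 on failure.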
template <typename Dtype>
int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const Dtype* A,
    int64_t lda,
    const Dtype* B,
    int64_t ldb,
    const void* pre_act,
    Dtype* C,
    int64_t ldc,
    Dtype* bgrad,
    bool is_gelu,
    int heuristic,
    void* lt_workspace,
    size_t workspaceSize) {
  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                "gemm_dact_bgradb_lt only supports fp16 and bf16");
  float beta = 0.0;
  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  constexpr int requestedAlgoCount = 5;
  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {};
  cublasLtEpilogue_t epilogue = is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;
  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for
  // details about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // Create preference handle; in general, extra attributes can be used here
  // to disable tensor ops or to make sure the selected algo works with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are
  // always well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  // We just need the best available heuristic to try and run the matmul.
  // There is no guarantee that this will work. For example, if A is badly
  // aligned, you can request more (e.g. 32) algos and try to run them one by
  // one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          &alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          &beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult[heuristic].algo,
                          lt_workspace,
                          workspaceSize,
                          at::cuda::getCurrentCUDAStream());
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
template int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* A,
    int64_t lda,
    const at::Half* B,
    int64_t ldb,
    const void* pre_act,
    at::Half* C,
    int64_t ldc,
    at::Half* bgrad,
    bool is_gelu,
    int heuristic,
    void* lt_workspace,
    size_t workspaceSize);

template int gemm_dact_bgradb_lt(
    cublasOperation_t transa,
    cublasOperation_t transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* A,
    int64_t lda,
    const at::BFloat16* B,
    int64_t ldb,
    const void* pre_act,
    at::BFloat16* C,
    int64_t ldc,
    at::BFloat16* bgrad,
    bool is_gelu,
    int heuristic,
    void* lt_workspace,
    size_t workspaceSize);

#endif
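
// Weight and bias gradient for a linear layer. Viewing the row-major PyTorch
// tensors as column-major: input is (in_features, batch_size), d_output is
// (out_features, batch_size), and d_weight is (in_features, out_features),
// i.e. the row-major (out_features, in_features) weight gradient. On the
// cuBLASLt path, d_bias = d_output summed over the batch is fused into the
// same matmul; note that the cublasGemmEx fallback computes only d_weight.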
template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bgradb_lt(
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_features,
      out_features,
      batch_size,
      alpha,
      input,
      in_features,
      d_output,
      out_features,
      d_weight,
      in_features,
      d_bias,
      lt_workspace,
      workspaceSize);
#endif
  if (status != 0) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    status = gemm_bias(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        in_features,
        out_features,
        batch_size,
        &alpha,
        input,
        in_features,
        d_output,
        out_features,
        &beta_zero,
        d_weight,
        in_features);
    // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error
    // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341
    // at::cuda::blas::gemm<T>(
    //     'N',
    //     'T',
    //     in_features,
    //     out_features,
    //     batch_size,
    //     alpha,
    //     input,
    //     in_features,
    //     d_output,
    //     out_features,
    //     beta_zero,
    //     d_weight,
    //     in_features);
  }
  return status;
}
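
// Fused dense forward: computes the row-major
//   output = act(input @ weight^T + bias)
// with act = GELU or ReLU. pre_act, if non-null, receives the epilogue's
// auxiliary output for use in the backward pass. Requires the cuBLASLt path;
// returns 1 when it is unavailable.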
template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  return gemm_bias_act_lt(
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      out_features,
      batch_size,
      in_features,
      /*alpha=*/1.0,
      weight,
      in_features,
      input,
      in_features,
      bias,
      output,
      out_features,
      pre_act,
      is_gelu,
      heuristic,
      lt_workspace,
      workspaceSize);
#else
  return 1;
#endif
}
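
// Fused backward through the activation: computes, in column-major terms,
//   d_input = dact(weight * d_output)
// using the pre_act buffer stashed by the forward pass for the activation
// derivative, and fuses the gradient of the bias that feeds the activation
// into d_bias in the same matmul. Requires the cuBLASLt path.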
template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
  const float alpha = 1.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_dact_bgradb_lt(
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      in_features,
      batch_size,
      out_features,
      alpha,
      weight,
      in_features,
      d_output,
      out_features,
      pre_act,
      d_input,
      in_features,
      d_bias,
      is_gelu,
      heuristic,
      lt_workspace,
      workspaceSize);
#endif
  return status;
}
template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
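
// A hedged usage sketch (disabled; the function name and the 32 MiB workspace
// size are illustrative assumptions, not part of this file): dispatching
// linear_act_forward_cuda on a tensor's dtype from a C++ extension.
#if 0
at::Tensor fused_dense_gelu_forward_example(const at::Tensor& input,
                                            const at::Tensor& weight,
                                            const at::Tensor& bias) {
  const int64_t batch_size = input.size(0);
  const int64_t in_features = input.size(1);
  const int64_t out_features = weight.size(0);
  auto output = at::empty({batch_size, out_features}, input.options());
  // Buffer for the epilogue's auxiliary output (pre-GELU values).
  auto pre_act = at::empty({batch_size, out_features}, input.options());
  // Workspace size is a caller choice; 32 MiB is a common default.
  auto lt_workspace = at::empty({32 * 1024 * 1024}, input.options().dtype(at::kByte));
  int rc = 1;
  if (input.scalar_type() == at::ScalarType::Half) {
    rc = linear_act_forward_cuda<at::Half>(
        input.data_ptr<at::Half>(), weight.data_ptr<at::Half>(),
        bias.data_ptr<at::Half>(), in_features, batch_size, out_features,
        /*is_gelu=*/true, /*heuristic=*/1, output.data_ptr<at::Half>(),
        pre_act.data_ptr(), lt_workspace.data_ptr(),
        (size_t)lt_workspace.numel());
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    rc = linear_act_forward_cuda<at::BFloat16>(
        input.data_ptr<at::BFloat16>(), weight.data_ptr<at::BFloat16>(),
        bias.data_ptr<at::BFloat16>(), in_features, batch_size, out_features,
        /*is_gelu=*/true, /*heuristic=*/1, output.data_ptr<at::BFloat16>(),
        pre_act.data_ptr(), lt_workspace.data_ptr(),
        (size_t)lt_workspace.numel());
  }
  TORCH_CHECK(rc == 0, "fused dense forward failed");
  return output;
}
#endif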