// fused_dense_cuda.cu
// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

#include <type_traits>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#include <cublasLt.h>
#endif
  15. // FP16 Tensor core wrapper around cublas GEMMEx
  16. cublasStatus_t gemm_bias(
  17. cublasHandle_t handle,
  18. cublasOperation_t transa,
  19. cublasOperation_t transb,
  20. int64_t m,
  21. int64_t n,
  22. int64_t k,
  23. const float* alpha,
  24. const at::Half* A,
  25. int64_t lda,
  26. const at::Half* B,
  27. int64_t ldb,
  28. const float* beta,
  29. at::Half* C,
  30. int64_t ldc) {
  31. return cublasGemmEx(
  32. handle,
  33. transa,
  34. transb,
  35. m,
  36. n,
  37. k,
  38. alpha,
  39. A,
  40. CUDA_R_16F,
  41. lda,
  42. B,
  43. CUDA_R_16F,
  44. ldb,
  45. beta,
  46. C,
  47. CUDA_R_16F,
  48. ldc,
  49. CUDA_R_32F,
  50. CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  51. }
  52. // BF16 Tensor core wrapper around cublas GEMMEx
  53. cublasStatus_t gemm_bias(
  54. cublasHandle_t handle,
  55. cublasOperation_t transa,
  56. cublasOperation_t transb,
  57. int64_t m,
  58. int64_t n,
  59. int64_t k,
  60. const float* alpha,
  61. const at::BFloat16* A,
  62. int64_t lda,
  63. const at::BFloat16* B,
  64. int64_t ldb,
  65. const float* beta,
  66. at::BFloat16* C,
  67. int64_t ldc) {
  68. return cublasGemmEx(
  69. handle,
  70. transa,
  71. transb,
  72. m,
  73. n,
  74. k,
  75. alpha,
  76. A,
  77. CUDA_R_16BF,
  78. lda,
  79. B,
  80. CUDA_R_16BF,
  81. ldb,
  82. beta,
  83. C,
  84. CUDA_R_16BF,
  85. ldc,
  86. CUDA_R_32F,
  87. CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  88. }
  89. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  90. template <typename Dtype>
  91. int gemm_bias_act_lt(
  92. cublasOperation_t transa,
  93. cublasOperation_t transb,
  94. int64_t m,
  95. int64_t n,
  96. int64_t k,
  97. float alpha,
  98. const Dtype* A,
  99. int64_t lda,
  100. const Dtype* B,
  101. int64_t ldb,
  102. const Dtype* bias,
  103. Dtype* C,
  104. int64_t ldc,
  105. void* pre_act,
  106. bool is_gelu,
  107. int heuristic,
  108. void *lt_workspace,
  109. size_t workspaceSize
  110. ) {
  111. static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
  112. "gemm_bias_act_lt only supports fp16 and bf16");
  113. bool save_pre_act = pre_act != nullptr;
  114. float beta = 0.0;
  115. cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  116. cublasLtHandle_t ltHandle =
  117. reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  118. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  119. cublasLtMatmulDescOpaque_t operationDesc = {};
  120. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  121. cublasLtMatmulPreferenceOpaque_t preference = {};
  122. int returnedResults = 0;
  123. constexpr int requestedAlgoCount = 5;
  124. cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
  125. // constexpr int requestedAlgoCount = 1;
  126. // cublasLtMatmulHeuristicResult_t heuristicResult = {};
  127. cublasLtEpilogue_t epilogue = is_gelu
  128. ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
  129. : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);
  130. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  131. // for details about defaults; here we just set the transforms for
  132. // A and B.
  133. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  134. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  135. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  136. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  137. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  138. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  139. if (save_pre_act) {
  140. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
  141. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  142. }
  143. if (bias != nullptr) {
  144. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  145. if (status != CUBLAS_STATUS_SUCCESS) {
  146. goto CLEANUP;
  147. }
  148. epilogue = is_gelu
  149. ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)
  150. : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS);
  151. } else {
  152. epilogue = is_gelu
  153. ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)
  154. : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);
  155. }
  156. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  157. if (status != CUBLAS_STATUS_SUCCESS) {
  158. goto CLEANUP;
  159. }
  160. // Create matrix descriptors. Not setting any extra attributes.
  161. status = cublasLtMatrixLayoutInit(
  162. &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  163. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  164. status = cublasLtMatrixLayoutInit(
  165. &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  166. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  167. status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  168. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  169. // Create preference handle; In general, extra attributes can be
  170. // used here to disable tensor ops or to make sure algo selected
  171. // will work with badly aligned A, B, C. However, for simplicity
  172. // here we assume A,B,C are always well aligned (e.g., directly
  173. // come from cudaMalloc)
  174. status = cublasLtMatmulPreferenceInit(&preference);
  175. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  176. status = cublasLtMatmulPreferenceSetAttribute(
  177. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  178. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  179. // We just need the best available heuristic to try and run matmul.
  180. // There is no guarantee that this will work. For example, if A is
  181. // badly aligned, you can request more (e.g. 32) algos and try to
  182. // run them one by one until something works.
  183. status = cublasLtMatmulAlgoGetHeuristic(
  184. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  185. // ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  186. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  187. if (returnedResults == 0) {
  188. status = CUBLAS_STATUS_NOT_SUPPORTED;
  189. goto CLEANUP;
  190. }
  191. status = cublasLtMatmul(ltHandle,
  192. &operationDesc,
  193. &alpha,
  194. A,
  195. &Adesc,
  196. B,
  197. &Bdesc,
  198. &beta,
  199. C,
  200. &Cdesc,
  201. C,
  202. &Cdesc,
  203. // &heuristicResult.algo,
  204. // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
  205. &heuristicResult[heuristic].algo,
  206. // NULL,
  207. lt_workspace,
  208. workspaceSize,
  209. at::cuda::getCurrentCUDAStream());
  210. CLEANUP:
  211. // Descriptors are no longer needed as all GPU work was already
  212. // enqueued.
  213. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  214. }
  215. template int gemm_bias_act_lt(
  216. cublasOperation_t transa,
  217. cublasOperation_t transb,
  218. int64_t m,
  219. int64_t n,
  220. int64_t k,
  221. float alpha,
  222. const at::Half* A,
  223. int64_t lda,
  224. const at::Half* B,
  225. int64_t ldb,
  226. const at::Half* bias,
  227. at::Half* C,
  228. int64_t ldc,
  229. void* pre_act,
  230. bool is_gelu,
  231. int heuristic,
  232. void *lt_workspace,
  233. size_t workspaceSize);
  234. template int gemm_bias_act_lt(
  235. cublasOperation_t transa,
  236. cublasOperation_t transb,
  237. int64_t m,
  238. int64_t n,
  239. int64_t k,
  240. float alpha,
  241. const at::BFloat16* A,
  242. int64_t lda,
  243. const at::BFloat16* B,
  244. int64_t ldb,
  245. const at::BFloat16* bias,
  246. at::BFloat16* C,
  247. int64_t ldc,
  248. void* pre_act,
  249. bool is_gelu,
  250. int heuristic,
  251. void *lt_workspace,
  252. size_t workspaceSize);
  253. template <typename Dtype>
  254. int gemm_bgradb_lt(
  255. cublasOperation_t transa,
  256. cublasOperation_t transb,
  257. int64_t m,
  258. int64_t n,
  259. int64_t k,
  260. float alpha,
  261. const Dtype* A,
  262. int64_t lda,
  263. const Dtype* B,
  264. int64_t ldb,
  265. Dtype* C,
  266. int64_t ldc,
  267. Dtype* bgrad,
  268. void *lt_workspace,
  269. size_t workspaceSize) {
  270. static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
  271. "gemm_bgradb_lt only supports fp16 and bf16");
  272. float beta = 0.0;
  273. cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  274. cublasLtHandle_t ltHandle =
  275. reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  276. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  277. cublasLtMatmulDescOpaque_t operationDesc = {};
  278. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  279. cublasLtMatmulPreferenceOpaque_t preference = {};
  280. int returnedResults = 0;
  281. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  282. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  283. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  284. // for details about defaults; here we just set the transforms for
  285. // A and B.
  286. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  287. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  288. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  289. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  290. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  291. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  292. if (bgrad != nullptr) {
  293. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  294. if (status != CUBLAS_STATUS_SUCCESS) {
  295. goto CLEANUP;
  296. }
  297. epilogue = CUBLASLT_EPILOGUE_BGRADB;
  298. }
  299. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  300. if (status != CUBLAS_STATUS_SUCCESS) {
  301. goto CLEANUP;
  302. }
  303. // Create matrix descriptors. Not setting any extra attributes.
  304. status = cublasLtMatrixLayoutInit(
  305. &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  306. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  307. status = cublasLtMatrixLayoutInit(
  308. &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  309. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  310. status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  311. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  312. // Create preference handle; In general, extra attributes can be
  313. // used here to disable tensor ops or to make sure algo selected
  314. // will work with badly aligned A, B, C. However, for simplicity
  315. // here we assume A,B,C are always well aligned (e.g., directly
  316. // come from cudaMalloc)
  317. status = cublasLtMatmulPreferenceInit(&preference);
  318. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  319. status = cublasLtMatmulPreferenceSetAttribute(
  320. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  321. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  322. // We just need the best available heuristic to try and run matmul.
  323. // There is no guarantee that this will work. For example, if A is
  324. // badly aligned, you can request more (e.g. 32) algos and try to
  325. // run them one by one until something works.
  326. status = cublasLtMatmulAlgoGetHeuristic(
  327. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  328. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  329. if (returnedResults == 0) {
  330. status = CUBLAS_STATUS_NOT_SUPPORTED;
  331. goto CLEANUP;
  332. }
  333. status = cublasLtMatmul(ltHandle,
  334. &operationDesc,
  335. &alpha,
  336. A,
  337. &Adesc,
  338. B,
  339. &Bdesc,
  340. &beta,
  341. C,
  342. &Cdesc,
  343. C,
  344. &Cdesc,
  345. //&heuristicResult.algo,
  346. NULL,
  347. lt_workspace,
  348. workspaceSize,
  349. at::cuda::getCurrentCUDAStream());
  350. CLEANUP:
  351. // Descriptors are no longer needed as all GPU work was already
  352. // enqueued.
  353. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  354. }
  355. template int gemm_bgradb_lt(
  356. cublasOperation_t transa,
  357. cublasOperation_t transb,
  358. int64_t m,
  359. int64_t n,
  360. int64_t k,
  361. float alpha,
  362. const at::Half* A,
  363. int64_t lda,
  364. const at::Half* B,
  365. int64_t ldb,
  366. at::Half* C,
  367. int64_t ldc,
  368. at::Half* bgrad,
  369. void *lt_workspace,
  370. size_t workspaceSize);
  371. template int gemm_bgradb_lt(
  372. cublasOperation_t transa,
  373. cublasOperation_t transb,
  374. int64_t m,
  375. int64_t n,
  376. int64_t k,
  377. float alpha,
  378. const at::BFloat16* A,
  379. int64_t lda,
  380. const at::BFloat16* B,
  381. int64_t ldb,
  382. at::BFloat16* C,
  383. int64_t ldc,
  384. at::BFloat16* bgrad,
  385. void *lt_workspace,
  386. size_t workspaceSize);
  387. template <typename Dtype>
  388. int gemm_dact_bgradb_lt(
  389. cublasOperation_t transa,
  390. cublasOperation_t transb,
  391. int64_t m,
  392. int64_t n,
  393. int64_t k,
  394. float alpha,
  395. const Dtype* A,
  396. int64_t lda,
  397. const Dtype* B,
  398. int64_t ldb,
  399. const void* pre_act,
  400. Dtype* C,
  401. int64_t ldc,
  402. Dtype* bgrad,
  403. bool is_gelu,
  404. int heuristic,
  405. void *lt_workspace,
  406. size_t workspaceSize) {
  407. static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
  408. "gemm_dact_bgradb_lt only supports fp16 and bf16");
  409. float beta = 0.0;
  410. cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;
  411. cublasLtHandle_t ltHandle =
  412. reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  413. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  414. cublasLtMatmulDescOpaque_t operationDesc = {};
  415. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  416. cublasLtMatmulPreferenceOpaque_t preference = {};
  417. int returnedResults = 0;
  418. constexpr int requestedAlgoCount = 5;
  419. cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};
  420. cublasLtEpilogue_t epilogue = is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;
  421. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  422. // for details about defaults; here we just set the transforms for
  423. // A and B.
  424. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  425. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  426. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  427. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  428. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  429. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  430. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  431. if (status != CUBLAS_STATUS_SUCCESS) {
  432. goto CLEANUP;
  433. }
  434. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));
  435. if (status != CUBLAS_STATUS_SUCCESS) {
  436. goto CLEANUP;
  437. }
  438. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  439. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  440. if (status != CUBLAS_STATUS_SUCCESS) {
  441. goto CLEANUP;
  442. }
  443. // Create matrix descriptors. Not setting any extra attributes.
  444. status = cublasLtMatrixLayoutInit(
  445. &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  446. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  447. status = cublasLtMatrixLayoutInit(
  448. &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  449. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  450. status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);
  451. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  452. // Create preference handle; In general, extra attributes can be
  453. // used here to disable tensor ops or to make sure algo selected
  454. // will work with badly aligned A, B, C. However, for simplicity
  455. // here we assume A,B,C are always well aligned (e.g., directly
  456. // come from cudaMalloc)
  457. status = cublasLtMatmulPreferenceInit(&preference);
  458. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  459. status = cublasLtMatmulPreferenceSetAttribute(
  460. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  461. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  462. // We just need the best available heuristic to try and run matmul.
  463. // There is no guarantee that this will work. For example, if A is
  464. // badly aligned, you can request more (e.g. 32) algos and try to
  465. // run them one by one until something works.
  466. status = cublasLtMatmulAlgoGetHeuristic(
  467. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);
  468. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  469. if (returnedResults == 0) {
  470. status = CUBLAS_STATUS_NOT_SUPPORTED;
  471. goto CLEANUP;
  472. }
  473. status = cublasLtMatmul(ltHandle,
  474. &operationDesc,
  475. &alpha,
  476. A,
  477. &Adesc,
  478. B,
  479. &Bdesc,
  480. &beta,
  481. C,
  482. &Cdesc,
  483. C,
  484. &Cdesc,
  485. //&heuristicResult.algo,
  486. &heuristicResult[heuristic].algo,
  487. // NULL,
  488. lt_workspace,
  489. workspaceSize,
  490. at::cuda::getCurrentCUDAStream());
  491. CLEANUP:
  492. // Descriptors are no longer needed as all GPU work was already
  493. // enqueued.
  494. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  495. }
  496. template int gemm_dact_bgradb_lt(
  497. cublasOperation_t transa,
  498. cublasOperation_t transb,
  499. int64_t m,
  500. int64_t n,
  501. int64_t k,
  502. float alpha,
  503. const at::Half* A,
  504. int64_t lda,
  505. const at::Half* B,
  506. int64_t ldb,
  507. const void* pre_act,
  508. at::Half* C,
  509. int64_t ldc,
  510. at::Half* bgrad,
  511. bool is_gelu,
  512. int heuristic,
  513. void *lt_workspace,
  514. size_t workspaceSize);
  515. template int gemm_dact_bgradb_lt(
  516. cublasOperation_t transa,
  517. cublasOperation_t transb,
  518. int64_t m,
  519. int64_t n,
  520. int64_t k,
  521. float alpha,
  522. const at::BFloat16* A,
  523. int64_t lda,
  524. const at::BFloat16* B,
  525. int64_t ldb,
  526. const void* pre_act,
  527. at::BFloat16* C,
  528. int64_t ldc,
  529. at::BFloat16* bgrad,
  530. bool is_gelu,
  531. int heuristic,
  532. void *lt_workspace,
  533. size_t workspaceSize);
  534. #endif
  535. template <typename T>
  536. int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
  537. const float alpha = 1.0;
  538. const float beta_zero = 0.0;
  539. int status = 1;
  540. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  541. status = gemm_bgradb_lt(
  542. // (cublasLtHandle_t)handle,
  543. CUBLAS_OP_N,
  544. CUBLAS_OP_T,
  545. in_features,
  546. out_features,
  547. batch_size,
  548. alpha,
  549. input,
  550. in_features,
  551. d_output,
  552. out_features,
  553. d_weight,
  554. in_features,
  555. d_bias,
  556. lt_workspace,
  557. workspaceSize);
  558. #endif
  559. if (status != 0){
  560. cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  561. status = gemm_bias(
  562. handle,
  563. CUBLAS_OP_N,
  564. CUBLAS_OP_T,
  565. in_features,
  566. out_features,
  567. batch_size,
  568. &alpha,
  569. input,
  570. in_features,
  571. d_output,
  572. out_features,
  573. &beta_zero,
  574. d_weight,
  575. in_features);
  576. // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error
  577. // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341
  578. // at::cuda::blas::gemm<T>(
  579. // 'N',
  580. // 'T',
  581. // in_features,
  582. // out_features,
  583. // batch_size,
  584. // alpha,
  585. // input,
  586. // in_features,
  587. // d_output,
  588. // out_features,
  589. // beta_zero,
  590. // d_weight,
  591. // in_features);
  592. }
  593. return status;
  594. }
  595. template <typename T>
  596. int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
  597. int status = 1;
  598. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  599. status = gemm_bias_act_lt(
  600. CUBLAS_OP_T,
  601. CUBLAS_OP_N,
  602. out_features,
  603. batch_size,
  604. in_features,
  605. /*alpha=*/1.0,
  606. weight,
  607. in_features,
  608. input,
  609. in_features,
  610. bias,
  611. output,
  612. out_features,
  613. pre_act,
  614. is_gelu,
  615. heuristic,
  616. lt_workspace,
  617. workspaceSize);
  618. return status;
  619. #else
  620. return 1;
  621. #endif
  622. }
  623. template <typename T>
  624. int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
  625. const float alpha = 1.0;
  626. int status = 1;
  627. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  628. status = gemm_dact_bgradb_lt(
  629. CUBLAS_OP_N,
  630. CUBLAS_OP_N,
  631. in_features,
  632. batch_size,
  633. out_features,
  634. alpha,
  635. weight,
  636. in_features,
  637. d_output,
  638. out_features,
  639. pre_act,
  640. d_input,
  641. in_features,
  642. d_bias,
  643. is_gelu,
  644. heuristic,
  645. lt_workspace,
  646. workspaceSize);
  647. #endif
  648. return status;
  649. }
  650. template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
  651. template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
  652. template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
  653. template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
  654. template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
  655. template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);