// fp6_linear.cu
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is adapted from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu

#include "kernel_matmul.cuh"
#include "kernel_reduction.cuh"

#include <stdio.h>
#include <assert.h>

namespace aphrodite {
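
// Kernel_Ex: configures and launches one QUANT_GEMM_Kernel instance for a
// given tiling configuration. It sizes the dynamic shared memory, raises the
// per-kernel shared-memory limit, computes the launch grid, and launches the
// kernel on the caller's stream.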
template <typename TilingConfig, typename OutputDataType, int EXPONENT,
          int MANTISSA>
static void Kernel_Ex(cudaStream_t stream, const uint4* Weight,
                      const half* Scales, const half* B, OutputDataType* C,
                      const size_t M_Global, const size_t N_Global,
                      const size_t K_Global, int Split_K) {
#ifdef DEBUG_MODE
  printf("\n");
  printf("Launcher.cu->Kernel_Ex():\n");
  printf("M: %zu, N: %zu, K: %zu, SplitK: %d\n", M_Global, N_Global, K_Global,
         Split_K);
  printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M,
         TilingConfig::TILE_K, TilingConfig::TILE_N);
#endif
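  // Dynamic shared memory must cover the larger of the two usages: the A/B
  // tiles staged during the main GEMM loop, or the C tile used when writing
  // results out. Anything above the default 48 KB per-block limit has to be
  // opted into via cudaFuncSetAttribute.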
  static size_t SHMEM_SZ =
      max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE,
          TilingConfig::SMEM_SIZE_C_TILE);
  cudaFuncSetAttribute(
      QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
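  // Launch geometry: GridDim.x spans the N tiles and GridDim.y spans the M
  // tiles, replicated Split_K times so each K-slice gets its own set of
  // output tiles; each block runs BLOCK_WARPS warps.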
  size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
  size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
  dim3 GridDim(dimN, dimM, 1);
  dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
  //
#ifdef DEBUG_MODE
  printf(
      "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, "
      "BlockDim.y: %d, BlockDim.z: %d, SHMEM_SZ: %zu\n",
      GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z,
      SHMEM_SZ);
  printf("\n");
#endif
  QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>
      <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global,
                                                N_Global, K_Global, Split_K);
}
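
// fpx_linear_kernel: host-side launcher for the FPx-weight GEMM. It selects a
// tiling configuration from the (rounded-up) N dimension and then either lets
// the GEMM kernel accumulate directly into the half-precision output C
// (Split_K == 1), or has each K-slice write fp32 partial sums into
// Reduction_Workspace and sums them into C with SplitK_Reduction
// (Split_K > 1).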
template <int EXPONENT, int MANTISSA>
cudaError_t fpx_linear_kernel(
    cudaStream_t stream, const uint4* Weight, const half* Scales,
    const half* B, half* C, const size_t M_Global, const size_t N_Global,
    const size_t K_Global,
    float* Reduction_Workspace,  // Reduction_Workspace_Size = Split_K *
                                 // M_Global * N_Global * sizeof(fp32)
    int Split_K) {
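  // Shape requirements: M must be a multiple of 256 and K a multiple of 64
  // (these follow from the packed-weight layout); N is only used to choose a
  // tiling configuration below.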
  assert(M_Global % 256 == 0);
  assert(K_Global % 64 == 0);
  assert(N_Global > 0);

  // Workaround to support more N shapes: round N up to the nearest supported
  // tile width when selecting a tiling configuration.
  size_t N_PowerOf2;
  if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8;
  if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16;
  if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32;
  if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64;
  if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128;
  if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128;
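  // Only the tiling configuration depends on N_PowerOf2; the kernels below are
  // still launched with the true N_Global, so the activations need no padding.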
  if (Split_K == 1) {
    switch (N_PowerOf2) {
      case 8:
        Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 16:
        Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 32:
        Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 64:
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 128:
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      default:
        if (N_PowerOf2 % 128 != 0) {
          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n",
                 N_PowerOf2);
          return cudaErrorUnknown;
        }
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
    }
  } else {
    switch (N_PowerOf2) {
      case 8:
        Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 16:
        Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 32:
        Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 64:
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 128:
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      default:
        if (N_PowerOf2 % 128 != 0) {
          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n",
                 N_PowerOf2);
          return cudaErrorUnknown;
        }
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
    }
    // Reduction for SplitK: sum the Split_K fp32 partial results held in
    // Reduction_Workspace into the half-precision output C.
    dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1,
                 1);
    dim3 BlockDim(WARP_SIZE, 1, 1);
    SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(
        C, Reduction_Workspace, M_Global, N_Global, Split_K);
  }
  return cudaGetLastError();
}
}  // namespace aphrodite

#include <torch/all.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

// MODIFICATION NOTE: dtype of _weights is changed to uint8
/*
Computes FPx-FP16 GEMM (PyTorch interface).

[Mathematical Formula]
Standard definition of a linear layer: Out = In * trans(W), where In, Out, and
W are stored in row-major. After the equivalent transformation: trans(Out) =
W * trans(In). Note that we do not perform a "transpose" at runtime; we
instead interpret In/Out as column-major matrices when calling our CUDA
kernel.

[Inputs]
  _in_feats: tensor of shape [B, IC];               // half
  _weights:  int tensor of shape [OC, IC // 8 * x]; // x UINT8 words hold
                                                    // 8 FPx weights.
  _scales:   tensor of shape [OC];                  // half
  splitK:    splits the MatMul problem along the K dimension for higher GPU
             utilization; default 1.

[Outputs]
  _out_feats: tensor of shape [B, OC];              // half
*/
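// Worked example (illustrative numbers, not from the original source): for FP6
// with EXPONENT = 3 and MANTISSA = 2, NBITS = 1 + 3 + 2 = 6, so x = 6 and a
// layer with IC = 4096 packs 4096 / 8 * 6 = 3072 uint8 words per output
// channel, i.e. _weights has shape [OC, 3072].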
torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
                                          torch::Tensor _in_feats,
                                          torch::Tensor _weights,
                                          torch::Tensor _scales,
                                          int64_t splitK = 1) {
  const int64_t NBITS = 1 + EXPONENT + MANTISSA;
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  int num_out_channels = _weights.size(0);
  TORCH_CHECK(num_in_channels % 64 == 0,
              "Expected in_features to be a multiple of 64, but received ",
              num_in_channels);
  TORCH_CHECK((num_in_channels / 8 * NBITS) ==
              _weights.size(1));  // Making sure the K dimension is matched.
  //
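  // Map the PyTorch tensors onto the kernel's transposed formulation
  // trans(Out) = W * trans(In): the GEMM's M is the number of output channels,
  // K the number of input channels, and N the batch size.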
  int M = num_out_channels;
  int K = num_in_channels;
  int N = num_in_feats;
  // Input Tensors
  auto weight = reinterpret_cast<const uint4*>(
      _weights.data_ptr<uint8_t>());  // weights is [OC, IC], but packed as FPx.
  auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
  auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
  // Output Tensors
  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({num_in_feats, num_out_channels}, options);
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

  options =
      torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
  at::Tensor _workspace =
      torch::empty({splitK, num_in_feats, num_out_channels}, options);
  auto Reduction_Workspace = reinterpret_cast<float*>(
      _workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K *
                                      // M_Global * N_Global * sizeof(fp32)

  // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of the
  // default stream (0); this fixes a problem with CUDA graphs when used with
  // torch.compile()
  auto stream = at::cuda::getCurrentCUDAStream();

  /*
    The heuristic is weight_bit - exponent_bit - 1 = mantissa_bit
  */
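  // Dispatch to the template instantiation matching (EXPONENT, MANTISSA); the
  // branches below enumerate every FPx format from 2 to 7 total bits
  // (sign + exponent + mantissa).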
  // FP2
  if (EXPONENT == 1 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<1, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP3
  else if (EXPONENT == 1 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<1, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP4
  else if (EXPONENT == 1 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<1, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP5
  else if (EXPONENT == 1 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<1, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<4, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP6
  else if (EXPONENT == 1 && MANTISSA == 4)
    aphrodite::fpx_linear_kernel<1, 4>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<4, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 5 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<5, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP7
  else if (EXPONENT == 1 && MANTISSA == 5)
    aphrodite::fpx_linear_kernel<1, 5>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 4)
    aphrodite::fpx_linear_kernel<2, 4>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<3, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<4, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 5 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<5, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else
    TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA,
                " is not supported.");
  return _out_feats;
}
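
// Note (not part of the original file): this entry point is typically exposed
// to Python through a TORCH_LIBRARY registration defined elsewhere in the
// build. A minimal sketch of such a registration, assuming the library is
// named "aphrodite", might look like:
//
//   TORCH_LIBRARY_FRAGMENT(aphrodite, m) {
//     m.def(
//         "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA, "
//         "Tensor _in_feats, Tensor _weights, Tensor _scales, "
//         "int splitK) -> Tensor");
//     m.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
//            &fp_eXmY_linear_forward_cuda);
//   }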