// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is adapted from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu

#include "kernel_matmul.cuh"
#include "kernel_reduction.cuh"

#include <stdio.h>
#include <assert.h>

namespace aphrodite {

template <typename TilingConfig, typename OutputDataType, int EXPONENT,
          int MANTISSA>
static void Kernel_Ex(cudaStream_t stream, const uint4* Weight,
                      const half* Scales, const half* B, OutputDataType* C,
                      const size_t M_Global, const size_t N_Global,
                      const size_t K_Global, int Split_K) {
#ifdef DEBUG_MODE
  printf("\n");
  printf("Launcher.cu->Kernel_Ex():\n");
  printf("M: %zu, N: %zu, K: %zu, SplitK: %d\n", M_Global, N_Global, K_Global,
         Split_K);
  printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M,
         TilingConfig::TILE_K, TilingConfig::TILE_N);
#endif
  // Shared memory is sized for the larger of the A/B input tiles and the C
  // output tile, since the two uses do not overlap in time.
  static size_t SHMEM_SZ =
      max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE,
          TilingConfig::SMEM_SIZE_C_TILE);
  cudaFuncSetAttribute(
      QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
  size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
  size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
  dim3 GridDim(dimN, dimM, 1);
  dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
#ifdef DEBUG_MODE
  printf(
      "GridDim.x: %u, GridDim.y: %u, GridDim.z: %u, BlockDim.x: %u, "
      "BlockDim.y: %u, BlockDim.z: %u SHMEM_SZ: %zu\n",
      GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z,
      SHMEM_SZ);
  printf("\n");
#endif
  QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>
      <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global,
                                                N_Global, K_Global, Split_K);
}

template <int EXPONENT, int MANTISSA>
cudaError_t fpx_linear_kernel(
    cudaStream_t stream, const uint4* Weight, const half* Scales,
    const half* B, half* C, const size_t M_Global, const size_t N_Global,
    const size_t K_Global,
    float* Reduction_Workspace,  // Reduction_Workspace_Size = Split_K *
                                 // M_Global * N_Global * sizeof(fp32)
    int Split_K) {
  assert(M_Global % 256 == 0);
  assert(K_Global % 64 == 0);
  assert(N_Global > 0);

  // Workaround to support more N shapes: round N up to the next supported
  // tile width before selecting a kernel instantiation.
  size_t N_PowerOf2;
  if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8;
  if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16;
  if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32;
  if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64;
  if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128;
  if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128;
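  // For example, N_Global = 24 rounds up to N_PowerOf2 = 32, and
  // N_Global = 200 rounds up to N_PowerOf2 = 256, the next multiple of 128.
  // Every batch size therefore maps onto one of the tiling configurations
  // instantiated below, while the true N_Global is still passed to the
  // kernel for boundary handling.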
  if (Split_K == 1) {
    // Split_K == 1: the kernel accumulates the full K dimension itself and
    // writes FP16 results directly to C.
    switch (N_PowerOf2) {
      case 8:
        Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 16:
        Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 32:
        Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 64:
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 128:
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      default:
        if (N_PowerOf2 % 128 != 0) {
          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n",
                 N_PowerOf2);
          return cudaErrorUnknown;
        }
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
    }
  } else {
    // Split_K > 1: each slice writes FP32 partial sums to
    // Reduction_Workspace; a second kernel reduces them into C below.
    switch (N_PowerOf2) {
      case 8:
        Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 16:
        Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 32:
        Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 64:
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 128:
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      default:
        if (N_PowerOf2 % 128 != 0) {
          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n",
                 N_PowerOf2);
          return cudaErrorUnknown;
        }
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
    }
    // Reduction for SplitK
    dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1,
                 1);
    dim3 BlockDim(WARP_SIZE, 1, 1);
    SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(
        C, Reduction_Workspace, M_Global, N_Global, Split_K);
  }
  return cudaGetLastError();
}

}  // namespace aphrodite

#include <torch/all.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// MODIFICATION NOTE: dtype of _weights is changed to uint8
/*
Computes FPx-FP16 GEMM (PyTorch interface).

[Mathematical Formula]
Standard definition of linear layer:  Out = In * trans(W), where In, Out, and
W are stored in row-major.
After equivalent transformation:      trans(Out) = W * trans(In). Note that
we do not perform a "transpose" at runtime; we instead interpret In/Out as
column-major matrices when calling our CUDA kernel.

[Inputs]
  _in_feats: tensor of shape [B, IC];                // half
  _weights:  int tensor of shape [OC, IC // 8 * x];  // x UINT8 words contain
                                                     // 8 FPx weights.
  _scales:   tensor of shape [OC];                   // half
  splitK:    splits the MatMul problem along the K dimension for higher GPU
             utilization; default 1.
[Outputs]
  _out_feats: tensor of shape [B, OC];               // half
*/
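// A concrete instance of the transformation above (shapes illustrative
// only): with B = 16 and OC = 4096, the kernel produces a column-major
// trans(Out) of shape [OC, B] = [4096, 16] as W[OC, IC] * trans(In)[IC, B];
// reading that same buffer as a row-major [B, OC] = [16, 4096] tensor yields
// _out_feats directly, so no physical transpose is ever materialized.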
torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
                                          torch::Tensor _in_feats,
                                          torch::Tensor _weights,
                                          torch::Tensor _scales,
                                          int64_t splitK = 1) {
  const int64_t NBITS = 1 + EXPONENT + MANTISSA;
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  int num_out_channels = _weights.size(0);
  TORCH_CHECK(num_in_channels % 64 == 0,
              "Expected in_features to be a multiple of 64, but received ",
              num_in_channels);
  TORCH_CHECK((num_in_channels / 8 * NBITS) ==
              _weights.size(1));  // Making sure the K dimension is matched.
  //
  int M = num_out_channels;
  int K = num_in_channels;
  int N = num_in_feats;
  // Input Tensors
  auto weight = reinterpret_cast<const uint4*>(
      _weights.data_ptr<uint8_t>());  // weights is [OC, IC] but in FPx.
  auto in_feats =
      reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
  auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
  // Output Tensors
  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({num_in_feats, num_out_channels}, options);
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

  options =
      torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
  at::Tensor _workspace =
      torch::empty({splitK, num_in_feats, num_out_channels}, options);
  auto Reduction_Workspace = reinterpret_cast<float*>(
      _workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K *
                                      // M_Global * N_Global * sizeof(fp32)

  // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of the
  // default stream (0); this fixes a problem with CUDA graphs when used with
  // torch.compile().
  auto stream = at::cuda::getCurrentCUDAStream();
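  // A minimal sketch of the split-K reduction semantics (reference loop
  // only; the actual implementation is the SplitK_Reduction kernel in
  // kernel_reduction.cuh). Each of the splitK slices stores a contiguous
  // M * N block of FP32 partial sums in the workspace, so conceptually:
  //
  //   for (size_t i = 0; i < (size_t)M * N; ++i) {  // one output element
  //     float acc = 0.0f;
  //     for (int k = 0; k < splitK; ++k)            // fold splitK partials
  //       acc += Reduction_Workspace[(size_t)k * M * N + i];
  //     out_feats[i] = __float2half(acc);
  //   }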
  /*
    The heuristic is: weight_bit - exponent_bit - 1 = mantissa_bit
  */
  // FP2
  if (EXPONENT == 1 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<1, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP3
  else if (EXPONENT == 1 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<1, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP4
  else if (EXPONENT == 1 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<1, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP5
  else if (EXPONENT == 1 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<1, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<4, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP6
  else if (EXPONENT == 1 && MANTISSA == 4)
    aphrodite::fpx_linear_kernel<1, 4>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<4, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 5 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<5, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP7
  else if (EXPONENT == 1 && MANTISSA == 5)
    aphrodite::fpx_linear_kernel<1, 5>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 4)
    aphrodite::fpx_linear_kernel<2, 4>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<3, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<4, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 5 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<5, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else
    TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA,
                " is not supported.");

  return _out_feats;
}
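// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the adapted fp6_llm sources; the
// shapes are assumptions chosen for the example). Exercising the entry point
// from C++ for FP6 (E3M2) weights, with IC = 4096 (a multiple of 64) and
// OC = 4096 (a multiple of 256); in practice the packed uint8 weights and
// the fp16 per-channel scales come from the matching preprocessing step:
//
//   at::Tensor in = torch::randn(
//       {16, 4096}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   at::Tensor w = torch::empty(
//       {4096, 4096 / 8 * 6},
//       torch::dtype(torch::kUInt8).device(torch::kCUDA));  // packed FP6
//   at::Tensor s = torch::ones(
//       {4096}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   at::Tensor out = fp_eXmY_linear_forward_cuda(/*EXPONENT=*/3,
//                                                /*MANTISSA=*/2, in, w, s,
//                                                /*splitK=*/1);  // [16, 4096]
// ---------------------------------------------------------------------------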