// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is adapted from
// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
- #include "kernel_matmul.cuh"
- #include "kernel_reduction.cuh"
- #include <stdio.h>
- #include <assert.h>
- namespace aphrodite {
- template <typename TilingConfig, typename OutputDataType, int EXPONENT,
- int MANTISSA>
- static void Kernel_Ex(cudaStream_t stream, const uint4* Weight,
- const half* Scales, const half* B, OutputDataType* C,
- const size_t M_Global, const size_t N_Global,
- const size_t K_Global, int Split_K) {
#ifdef DEBUG_MODE
  printf("\n");
  printf("Launcher.cu->Kernel_Ex():\n");
  printf("M: %zu, N: %zu, K: %zu, SplitK: %d\n", M_Global, N_Global, K_Global,
         Split_K);
  printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M,
         TilingConfig::TILE_K, TilingConfig::TILE_N);
#endif
  static size_t SHMEM_SZ =
      max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE,
          TilingConfig::SMEM_SIZE_C_TILE);
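  // Kernels that request more than the default 48 KB of dynamic shared memory
  // must opt in explicitly via cudaFuncSetAttribute (compute capability 7.0+).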
  cudaFuncSetAttribute(
      QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
  size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
  size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
  dim3 GridDim(dimN, dimM, 1);
  dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
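  // One thread block per TILE_M x TILE_N output tile; the Split_K partitions
  // of the K dimension are stacked along GridDim.y, hence the Split_K factor
  // in dimM.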
  //
#ifdef DEBUG_MODE
  printf(
      "GridDim.x: %u, GridDim.y: %u, GridDim.z: %u, BlockDim.x: %u, "
      "BlockDim.y: %u, BlockDim.z: %u, SHMEM_SZ: %zu\n",
      GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z,
      SHMEM_SZ);
  printf("\n");
#endif
  QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>
      <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global,
                                                N_Global, K_Global, Split_K);
}
template <int EXPONENT, int MANTISSA>
cudaError_t fpx_linear_kernel(
    cudaStream_t stream, const uint4* Weight, const half* Scales,
    const half* B, half* C, const size_t M_Global, const size_t N_Global,
    const size_t K_Global,
    float* Reduction_Workspace,  // Reduction_Workspace_Size = Split_K *
                                 // M_Global * N_Global * sizeof(fp32)
    int Split_K) {
  assert(M_Global % 256 == 0);
  assert(K_Global % 64 == 0);
  assert(N_Global > 0);
  // Workaround to support more N shapes: round N up to the nearest supported
  // tile width.
  size_t N_PowerOf2;
  if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8;
  if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16;
  if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32;
  if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64;
  if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128;
  if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128;
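  // e.g. N_Global = 100 is rounded up to N_PowerOf2 = 128 and dispatched to
  // the case-128 tiling below.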
  if (Split_K == 1) {
    switch (N_PowerOf2) {
      case 8:
        Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 16:
        Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 32:
        Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 64:
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      case 128:
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
      default:
        if (N_PowerOf2 % 128 != 0) {
          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n",
                 N_PowerOf2);
          return cudaErrorUnknown;
        }
        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
            Split_K);
        break;
    }
  } else {
    switch (N_PowerOf2) {
      case 8:
        Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 16:
        Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 32:
        Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 64:
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      case 128:
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
      default:
        if (N_PowerOf2 % 128 != 0) {
          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n",
                 N_PowerOf2);
          return cudaErrorUnknown;
        }
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
            K_Global, Split_K);
        break;
    }
    // Reduction for SplitK
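    // Each of the Split_K launches above wrote one partial fp32 result into
    // Reduction_Workspace; SplitK_Reduction sums them into the final
    // half-precision output C.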
    dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1,
                 1);
    dim3 BlockDim(WARP_SIZE, 1, 1);
    SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(
        C, Reduction_Workspace, M_Global, N_Global, Split_K);
  }
  return cudaGetLastError();
}
}  // namespace aphrodite
#include <torch/all.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>
// MODIFICATION NOTE: dtype of _weights is changed to uint8
/*
Computes FPx-FP16 GEMM (PyTorch interface).

[Mathematical Formula]
Standard definition of a linear layer: Out = In * trans(W), where In, Out, and
W are stored in row-major. After an equivalent transformation, trans(Out) =
W * trans(In). Note that we do not perform a "transpose" at runtime; we
instead interpret In/Out as column-major matrices when calling our CUDA
kernel.

[Inputs]
  _in_feats: tensor of shape [B, IC];               // half
  _weights:  int tensor of shape [OC, IC // 8 * x]; // x UINT8 words contain
                                                    // 8 FPx weights.
  _scales:   tensor of shape [OC];                  // half
  splitK:    splits the MatMul problem along the K dimension for higher GPU
             utilization; default 1.

[Outputs]
  _out_feats: tensor of shape [B, OC];              // half
*/
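// A minimal usage sketch (hypothetical tensors, assumed to already satisfy the
// shape/dtype requirements above), here for FP6 E3M2, where each row of IC
// weights packs into IC / 8 * 6 uint8 words:
//
//   torch::Tensor out = fp_eXmY_linear_forward_cuda(
//       /*EXPONENT=*/3, /*MANTISSA=*/2,
//       in_feats,   // half,  [B, IC]
//       weights,    // uint8, [OC, IC / 8 * 6]
//       scales,     // half,  [OC]
//       /*splitK=*/1);
//   // out: half, [B, OC]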
torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
                                          torch::Tensor _in_feats,
                                          torch::Tensor _weights,
                                          torch::Tensor _scales,
                                          int64_t splitK = 1) {
  const int64_t NBITS = 1 + EXPONENT + MANTISSA;
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  int num_out_channels = _weights.size(0);
  TORCH_CHECK(num_in_channels % 64 == 0,
              "Expected in_features to be a multiple of 64, but received ",
              num_in_channels);
  TORCH_CHECK((num_in_channels / 8 * NBITS) ==
              _weights.size(1));  // Making sure the K dimension is matched.
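  // (Each group of 8 FPx weights, 8 * NBITS bits in total, packs into NBITS
  // uint8 words, so one row of IC weights occupies IC / 8 * NBITS bytes.)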
  //
  int M = num_out_channels;
  int K = num_in_channels;
  int N = num_in_feats;
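  // Note the mapping: the kernel evaluates trans(Out) = W * trans(In), so its
  // M is the output-channel count and its N is the batch size.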
  // Input Tensors
  auto weight = reinterpret_cast<const uint4*>(
      _weights.data_ptr<uint8_t>());  // weights is [OC, IC] but in FPx.
  auto in_feats =
      reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
  auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
  // Output Tensors
  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({num_in_feats, num_out_channels}, options);
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

  options =
      torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
  at::Tensor _workspace =
      torch::empty({splitK, num_in_feats, num_out_channels}, options);
  auto Reduction_Workspace = reinterpret_cast<float*>(
      _workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K *
                                      // M_Global * N_Global * sizeof(fp32)
  // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of the
  // default stream (0); this fixes a problem with CUDA graphs when used with
  // torch.compile().
  auto stream = at::cuda::getCurrentCUDAStream();
  /*
    Bit layout: weight_bit = 1 (sign) + exponent_bit + mantissa_bit, so
    mantissa_bit = weight_bit - exponent_bit - 1.
  */
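  // e.g. FP6 E3M2: 6 - 3 - 1 = 2 mantissa bits; FP5 E2M2: 5 - 2 - 1 = 2.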
  // FP2
  if (EXPONENT == 1 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<1, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP3
  else if (EXPONENT == 1 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<1, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP4
  else if (EXPONENT == 1 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<1, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP5
  else if (EXPONENT == 1 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<1, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<4, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP6
  else if (EXPONENT == 1 && MANTISSA == 4)
    aphrodite::fpx_linear_kernel<1, 4>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<4, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 5 && MANTISSA == 0)
    aphrodite::fpx_linear_kernel<5, 0>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  // FP7
  else if (EXPONENT == 1 && MANTISSA == 5)
    aphrodite::fpx_linear_kernel<1, 5>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 2 && MANTISSA == 4)
    aphrodite::fpx_linear_kernel<2, 4>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 3 && MANTISSA == 3)
    aphrodite::fpx_linear_kernel<3, 3>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 4 && MANTISSA == 2)
    aphrodite::fpx_linear_kernel<4, 2>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else if (EXPONENT == 5 && MANTISSA == 1)
    aphrodite::fpx_linear_kernel<5, 1>(stream, weight, scales, in_feats,
                                       out_feats, M, N, K, Reduction_Workspace,
                                       splitK);
  else
    TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA,
                " is not supported.");
  return _out_feats;
}