// Copyright 2024 FP6-LLM authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // This file is copied from // https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ***************************************************************************/ // Used for the reduction of result matrix if Split-K is used // Reduction_Workspace: (Split_K, M_Global, N_Global), column major // C: (M_Global, N_Global), column major // Each thread deals with 8 output elements, each elements is the sum of Split_K // elements // Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 // float_per_thread (256bit) -> 256 float per warp Write Global: Each // Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> // 256 half per warp // GridSize = (M_Global*N_Global) / 256 #include #include #include #define REDUCTION_ELEMENT_PER_THREADBLOCK 256 #define HALF_PER_128BIT 8 __global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) { half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; // Initializing Thread-Local Results float Results[HALF_PER_128BIT]; #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; // Reduction for (int i = 0; i < Split_K; i++) { #pragma unroll for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); }