// kernel_reduction.cuh
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from
// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh
/***************************************************************************
* Copyright 2023 The FLash-LLM Authors. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
***************************************************************************/
// Used for the reduction of the result matrix if Split-K is used
// Reduction_Workspace: (Split_K, M_Global, N_Global), column major
// C: (M_Global, N_Global), column major
// Each thread deals with 8 output elements, each element is the sum of Split_K
// elements
// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8
// float_per_thread (256bit) -> 256 float per warp. Write Global: Each
// Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) ->
// 256 half per warp
// GridSize = (M_Global*N_Global) / 256
  39. #include <cuda.h>
  40. #include <cuda_fp16.h>
  41. #include <cuda_runtime.h>
  42. #define REDUCTION_ELEMENT_PER_THREADBLOCK 256
  43. #define HALF_PER_128BIT 8
  44. __global__ void SplitK_Reduction(half* C, float* Reduction_Workspace,
  45. size_t M_Global, size_t N_Global,
  46. int Split_K) {
  47. half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
  48. float* WARP_GPTR_R =
  49. Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
  50. half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT;
  51. float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT;
  52. // Initializing Thread-Local Results
  53. float Results[HALF_PER_128BIT];
  54. #pragma unroll
  55. for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f;
  56. // Reduction
  57. for (int i = 0; i < Split_K; i++) {
  58. #pragma unroll
  59. for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j];
  60. THREAD_GPTR_R += M_Global * N_Global;
  61. }
  62. // Writing to global memory
  63. #pragma unroll
  64. for (int i = 0; i < HALF_PER_128BIT; i++)
  65. THREAD_GPTR_C[i] = __float2half_rn(Results[i]);
  66. }