// reduction.cuh — warp- and block-level reduction helpers.
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <climits>  // CHAR_BIT, used by _nextPow2

#include "cuda_compat.h"

namespace aphrodite {
  23. namespace detail {
  24. template <typename T>
  25. __inline__ __device__ T _max(T a, T b) {
  26. return max(a, b);
  27. }
  28. template <typename T>
  29. __inline__ __device__ T _sum(T a, T b) {
  30. return a + b;
  31. }
  32. } // namespace detail
  33. template <typename T>
  34. using ReduceFnType = T (*)(T, T);
  35. // Helper function to return the next largest power of 2
  36. static constexpr int _nextPow2(unsigned int num) {
  37. if (num <= 1) return num;
  38. #if defined(_MSC_VER) && !defined(__clang__) // MSVC without Clang
  39. // Decrement n (to handle cases when n itself is a power of 2)
  40. num--;
  41. // Set all bits after the first set bit
  42. num |= num >> 1;
  43. num |= num >> 2;
  44. num |= num >> 4;
  45. num |= num >> 8;
  46. num |= num >> 16;
  47. // Add 1 to get the next power of 2
  48. return num + 1;
  49. #else // GCC, Clang, or other compilers with __builtin_clz
  50. return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
  51. #endif
  52. }
  53. template <typename T, int numLanes = WARP_SIZE>
  54. __inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
  55. static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
  56. "numLanes is not a positive power of 2!");
  57. static_assert(numLanes <= WARP_SIZE);
  58. #pragma unroll
  59. for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
  60. val = fn(val, APHRODITE_SHFL_XOR_SYNC(val, mask));
  61. return val;
  62. }
  63. template <typename T, int maxBlockSize = 1024>
  64. __inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
  65. static_assert(maxBlockSize <= 1024);
  66. if constexpr (maxBlockSize > WARP_SIZE) {
  67. val = warpReduce<T>(val, fn);
  68. // Calculates max number of lanes that need to participate in the last
  69. // warpReduce
  70. constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
  71. static __shared__ T shared[maxActiveLanes];
  72. int lane = threadIdx.x % WARP_SIZE;
  73. int wid = threadIdx.x / WARP_SIZE;
  74. if (lane == 0) shared[wid] = val;
  75. __syncthreads();
  76. val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
  77. : (T)(0.0f);
  78. val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
  79. } else {
  80. // A single warpReduce is equal to blockReduce
  81. val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
  82. }
  83. return val;
  84. }
  85. template <typename T, int maxBlockSize = 1024>
  86. __inline__ __device__ T blockReduceMax(T val) {
  87. return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
  88. }
  89. template <typename T, int maxBlockSize = 1024>
  90. __inline__ __device__ T blockReduceSum(T val) {
  91. return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
  92. }
  93. } // namespace aphrodite