reduction.cuh 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. /*
  2. * Adapted from
  3. * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
  4. * Copyright (c) 2023, The PygmalionAI team.
  5. * Copyright (c) 2023, The vLLM team.
  6. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  7. *
  8. * Licensed under the Apache License, Version 2.0 (the "License");
  9. * you may not use this file except in compliance with the License.
  10. * You may obtain a copy of the License at
  11. *
  12. * http://www.apache.org/licenses/LICENSE-2.0
  13. *
  14. * Unless required by applicable law or agreed to in writing, software
  15. * distributed under the License is distributed on an "AS IS" BASIS,
  16. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. * See the License for the specific language governing permissions and
  18. * limitations under the License.
  19. */
  20. #pragma once
  21. #include "cuda_compat.h"
  namespace aphrodite {
  namespace detail {

  // Binary max combiner used by blockReduceMax. Requires a device-side `max`
  // overload for T.
  template <typename T>
  __inline__ __device__ T _max(T a, T b) {
    return max(a, b);
  }

  // Binary sum combiner used by blockReduceSum.
  template <typename T>
  __inline__ __device__ T _sum(T a, T b) {
    return a + b;
  }

  }  // namespace detail

  // Signature of the binary combiner accepted by warpReduce / blockReduce.
  template <typename T>
  using ReduceFnType = T (*)(T, T);

  // Returns the smallest power of two >= num (0 -> 0, 1 -> 1, 3 -> 4, 32 -> 32).
  // Used at compile time to size the final warpReduce. Written as a plain loop
  // so it needs neither the GCC/Clang-only __builtin_clz nor CHAR_BIT from
  // <climits>. Intended for small inputs (lane / thread counts); the result
  // must fit in an int.
  static constexpr int _nextPow2(unsigned int num) {
    if (num <= 1) return static_cast<int>(num);
    unsigned int pow2 = 1;
    while (pow2 < num) pow2 <<= 1;
    return static_cast<int>(pow2);
  }

  // XOR-butterfly reduction over aligned groups of `numLanes` lanes of a warp.
  // Assuming APHRODITE_SHFL_XOR_SYNC(v, m) returns v from lane (lane ^ m),
  // every lane of a group ends up holding fn applied over that group's inputs.
  //
  // Template params:
  //   numLanes - group width; must be a power of two and <= WARP_SIZE.
  // Params:
  //   val - this lane's input.
  //   fn  - associative, commutative binary combiner.
  // NOTE(review): the participating-lane mask is whatever
  // APHRODITE_SHFL_XOR_SYNC (cuda_compat.h) uses - presumably the full warp,
  // so all lanes of the warp should be active here; confirm.
  template <typename T, int numLanes = WARP_SIZE>
  __inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
    static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                  "numLanes is not a positive power of 2!");
    static_assert(numLanes <= WARP_SIZE);
    #pragma unroll
    for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
      val = fn(val, APHRODITE_SHFL_XOR_SYNC(val, mask));
    return val;
  }

  // Reduces `val` across all threads of a thread block with `fn`.
  //
  // Template params:
  //   maxBlockSize - compile-time upper bound on blockDim.x (<= 1024). It sizes
  //                  the shared scratch buffer and the final warpReduce.
  // Params:
  //   val - this thread's input.
  //   fn  - associative, commutative binary combiner.
  //   pad - value contributed to the final warpReduce by lanes that have no
  //         per-warp partial. It must not change the result: pass fn's identity
  //         (the default 0 is correct for sums) or, for an idempotent fn such
  //         as max, any element that is part of the reduction.
  // Returns:
  //   The block-wide result, guaranteed in thread 0. Other threads may hold
  //   only partial values.
  // Preconditions:
  //   - Every thread of the block must reach the call (it uses __syncthreads).
  //   - blockDim.x <= maxBlockSize. Threads are identified by threadIdx.x alone,
  //     so a 1-D block is assumed.
  //   - NOTE(review): if shuffles use a full-warp mask, blockDim.x should be a
  //     multiple of WARP_SIZE - confirm against cuda_compat.h.
  template <typename T, int maxBlockSize = 1024>
  __inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn,
                                      T pad = (T)(0.0f)) {
    static_assert(maxBlockSize <= 1024);
    if constexpr (maxBlockSize > WARP_SIZE) {
      // Stage 1: reduce within each warp.
      val = warpReduce<T>(val, fn);

      // One slot per warp of the largest allowed block. This function-scope
      // static buffer is shared by every call of this <T, maxBlockSize>
      // instantiation, including both blockReduceSum and blockReduceMax.
      constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
      static __shared__ T shared[maxActiveLanes];

      const int lane = threadIdx.x % WARP_SIZE;
      const int wid = threadIdx.x / WARP_SIZE;
      if (lane == 0) shared[wid] = val;
      __syncthreads();

      // Warps in this launch, counting a trailing partial warp. Threads
      // 0..numWarps-1 (all in warp 0) pick up one partial each; every other
      // thread contributes `pad` so it cannot corrupt the result.
      const unsigned numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
      val = (threadIdx.x < numWarps) ? shared[lane] : pad;

      // Stop any warp from starting another blockReduce and overwriting
      // `shared` before warp 0 has read this call's partials.
      __syncthreads();

      // Stage 2: combine the per-warp partials inside warp 0.
      val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
    } else {
      // The whole block fits in one warp: a single warpReduce suffices.
      val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
    }
    return val;
  }

  // Block-wide maximum of `val`; the result is guaranteed in thread 0 (see
  // blockReduce for preconditions). Each thread pads the final stage with its
  // own input rather than 0. Max is idempotent and that input is a real
  // element, so the result stays correct even when every value is negative.
  template <typename T, int maxBlockSize = 1024>
  __inline__ __device__ T blockReduceMax(T val) {
    return blockReduce<T, maxBlockSize>(val, detail::_max<T>, val);
  }

  // Block-wide sum of `val`; the result is guaranteed in thread 0 (see
  // blockReduce for preconditions). Padding uses blockReduce's default of 0,
  // the additive identity.
  template <typename T, int maxBlockSize = 1024>
  __inline__ __device__ T blockReduceSum(T val) {
    return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
  }

  }  // namespace aphrodite