reduction.cuh 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  19. #pragma once
  20. #include "cuda_compat.h"
  21. namespace aphrodite {
  /**
   * Sum `val` across a group of `numLanes` lanes and return the result.
   *
   * The loop starts at an offset of numLanes / 2 and halves it each step; at
   * every step the value obtained from APHRODITE_SHFL_XOR_SYNC(val, offset)
   * is added to the running total.
   *
   * NOTE(review): APHRODITE_SHFL_XOR_SYNC comes from cuda_compat.h and is not
   * visible here; presumably it wraps the XOR warp shuffle, in which case
   * every lane of the group ends up with the group total -- confirm.
   *
   * @tparam T        value type; must support `+=`.
   * @tparam numLanes group width; a positive power of two, at most WARP_SIZE
   *                  (both enforced at compile time).
   * @param  val      this thread's contribution.
   * @return the accumulated value held by the calling thread.
   */
  template<typename T, int numLanes = WARP_SIZE>
  __inline__ __device__ T warpReduceSum(T val) {
    static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                  "numLanes is not a positive power of 2!");
    static_assert(numLanes <= WARP_SIZE);

    // Offsets visited: numLanes/2, numLanes/4, ..., 1.
  #pragma unroll
    for (int laneOffset = numLanes / 2; laneOffset >= 1; laneOffset /= 2) {
      val += APHRODITE_SHFL_XOR_SYNC(val, laneOffset);
    }
    return val;
  }
  /**
   * Round `num` up to the nearest power of two.
   *
   * Inputs 0 and 1 are returned unchanged; an input that already is a power
   * of two is returned as-is.
   *
   * Implemented as a plain constexpr loop, so it depends neither on CHAR_BIT
   * (this header does not include <climits>) nor on the compiler-specific
   * __builtin_clz.
   *
   * @param num value to round up; the result must fit in `int`, so `num` must
   *            not exceed 2^30.
   * @return the smallest power of two >= num (or `num` itself when num <= 1).
   */
  static constexpr int _nextPow2(unsigned int num) {
    if (num <= 1) return num;
    int pow2 = 1;
    // Double once per significant bit of (num - 1): the result is
    // 2^bit_length(num - 1), the smallest power of two that is >= num.
    for (unsigned int rest = num - 1; rest != 0; rest >>= 1) {
      pow2 <<= 1;
    }
    return pow2;
  }
  /**
   * Calculate the sum of `val` over the threads of a block.
   *
   * When maxBlockSize fits in one warp, a single warpReduceSum over
   * _nextPow2(maxBlockSize) lanes is performed. Otherwise the reduction runs
   * in two stages: each warp reduces its own values, lane 0 of every warp
   * stores the warp's partial sum in a static shared-memory array, and after
   * a block barrier the partial sums are reduced once more.
   *
   * Only threadIdx.x / blockDim.x are consulted, so the block is presumably
   * expected to be one-dimensional -- TODO confirm against callers.
   * The two-stage path contains __syncthreads(), so every thread of the block
   * has to call this function.
   *
   * NOTE(review): the scratch array is static, so two calls without a barrier
   * in between look like they could race on it -- confirm with callers.
   * NOTE(review): after the second stage, threads outside the low lanes of
   * the first warp appear to hold partial/zero values rather than the block
   * total -- confirm before reading the result from other threads.
   *
   * @tparam T            value type.
   * @tparam maxBlockSize compile-time upper bound on the block size
   *                      (at most 1024); it sizes the scratch array.
   * @param  val          this thread's contribution.
   * @return the reduced value held by the calling thread.
   */
  template<typename T, int maxBlockSize = 1024>
  __inline__ __device__ T blockReduceSum(T val) {
    static_assert(maxBlockSize <= 1024);

    if constexpr (maxBlockSize <= WARP_SIZE) {
      // The whole block fits in one warp: one warp-level reduction suffices.
      val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
    } else {
      // Stage 1: reduce inside each warp.
      val = warpReduceSum<T>(val);

      // Upper bound on the number of warps, i.e. on the number of lanes that
      // take part in the final warp-level reduction.
      constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
      static __shared__ T warpSums[maxActiveLanes];

      const int laneId = threadIdx.x % WARP_SIZE;
      const int warpId = threadIdx.x / WARP_SIZE;

      // Lane 0 of each warp publishes that warp's partial sum.
      if (laneId == 0) {
        warpSums[warpId] = val;
      }
      __syncthreads();

      // Fractional warp count: a partially filled last warp still gets its
      // slot read back, because the comparison is done in floating point.
      const float warpsInBlock = blockDim.x / float(WARP_SIZE);
      if (threadIdx.x < warpsInBlock) {
        val = warpSums[laneId];
      } else {
        val = T(0.0f);
      }

      // Stage 2: reduce the per-warp partial sums.
      val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
    }
    return val;
  }
  59. } // namespace aphrodite