reduction.cuh 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. /*
  2. * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
  3. * Copyright (c) 2023, The PygmalionAI team.
  4. * Copyright (c) 2023, The vLLM team.
  5. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. #pragma once
  20. #include "cuda_compat.h"
  namespace aphrodite {

  /*
   * Sum `val` across a group of 32 lanes.
   *
   * Runs five exchange steps with XOR lane masks 16, 8, 4, 2, 1 through
   * APHRODITE_SHFL_XOR_SYNC, adding the exchanged value to the running sum at
   * every step. The hard-coded starting mask of 16 fixes the group width to
   * 32 lanes, matching the lane/warp indexing used by blockReduceSum below.
   *
   * NOTE(review): APHRODITE_SHFL_XOR_SYNC is defined in cuda_compat.h, which
   * is not visible here. Presumably it wraps a full-mask __shfl_xor_sync; if
   * so, every lane of the group must execute this call and every lane ends up
   * holding the complete sum -- confirm against the macro definition.
   */
  template<typename T>
  __inline__ __device__ T warpReduceSum(T val) {
    #pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
      val += APHRODITE_SHFL_XOR_SYNC(val, mask);
    return val;
  }

  /*
   * Calculate the sum of `val` over all threads in a block.
   *
   * Every 32-lane warp first reduces its own values with warpReduceSum. Lane 0
   * of each warp then stores the warp's partial sum in a shared-memory slot,
   * and after a block-wide barrier the partial sums are reduced once more
   * with warpReduceSum.
   *
   * Preconditions and behaviour:
   *  - The function executes __syncthreads(), so it must be reached by every
   *    thread of the block (never call it from divergent control flow).
   *  - Threads are identified by threadIdx.x only and there are 32 partial-sum
   *    slots, i.e. a one-dimensional block of at most 1024 threads is
   *    supported.
   *  - Only threads with threadIdx.x < ceil(blockDim.x / 32) load a partial
   *    sum; all other threads feed 0 into the second reduction. Those loading
   *    threads all belong to the first warp, so the block total is only
   *    meaningful in the first warp; threads of every other warp reduce zeros.
   *  - The shared scratch array is reused on every call. A barrier is issued
   *    before it is written so that consecutive calls cannot overwrite slots
   *    that the first warp is still reading from the previous call.
   *
   * NOTE(review): when blockDim.x is not a multiple of 32 the last warp is
   * only partially populated; whether the shuffles in warpReduceSum are
   * well-defined in that case depends on the participation mask used by
   * APHRODITE_SHFL_XOR_SYNC -- confirm before launching with such a block size.
   */
  template<typename T>
  __inline__ __device__ T blockReduceSum(T val) {
    static __shared__ T shared[32];
    const int lane = threadIdx.x & 0x1f;  // position inside the 32-lane warp
    const int wid = threadIdx.x >> 5;     // index of the warp inside the block

    val = warpReduceSum<T>(val);

    // Write-after-read guard: a previous blockReduceSum call may still be
    // reading `shared` in the first warp while other warps have already
    // returned and re-entered this function.
    __syncthreads();

    if (lane == 0)
      shared[wid] = val;

    __syncthreads();

    // Number of warps in the block, rounded up so that a trailing partial
    // warp (blockDim.x not divisible by 32) still contributes its slot.
    // Integer ceil-division selects exactly the same threads as the former
    // floating-point test `threadIdx.x < blockDim.x / 32.f`.
    const unsigned int numWarps = (blockDim.x + 31u) >> 5;
    val = (threadIdx.x < numWarps) ? shared[lane] : (T)(0.0f);
    val = warpReduceSum<T>(val);
    return val;
  }

  } // namespace aphrodite