fast_hadamard_transform_common.h

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#define FULL_MASK 0xffffffff

////////////////////////////////////////////////////////////////////////////////////////////////////
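
// uint8 is a 32-byte aggregate (two uint4), the widest unit used for vectorized memory
// accesses below. BytesToType maps a compile-time byte count (1, 2, 4, 8, 16, or 32) to a
// plain type of exactly that size, so loads and stores can be issued as a single vector
// transaction of the requested width.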
struct uint8 {
    uint4 u;
    uint4 v;
};

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<32> {
    using Type = uint8;
    static_assert(sizeof(Type) == 32);
};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////
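
// SumOp is a binary reduction functor. Allreduce<THREADS> combines a value held by THREADS
// lanes with a butterfly of XOR shuffles: each step halves the exchange distance until the
// Allreduce<2> base case swaps neighboring lanes, after which every participating lane holds
// the full reduction result.
//
// Minimal usage sketch (assuming all 32 lanes of the warp are active):
//   SumOp<float> sum_op;
//   float warp_sum = Allreduce<32>::run(my_val, sum_op);  // every lane ends up with the total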
template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// https://stackoverflow.com/questions/35311711/whats-the-right-way-to-compute-integral-base-2-logarithms-at-compile-time
constexpr int cilog2(int val) { return val > 0 ? 1 + cilog2(val >> 1) : -1; }
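// cilog2 is a compile-time floor(log2): cilog2(1) == 0, cilog2(8) == 3, and cilog2(0) == -1.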

////////////////////////////////////////////////////////////////////////////////////////////////////
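
// In-register Hadamard transform: each thread owns kNChunks independent chunks of
// N = 2^kLogN floats and applies the standard butterfly, doubling the stride on each of the
// kLogN rounds. Each pair (idx, idx + stride) is replaced by its sum and difference, which
// yields the (unnormalized) Walsh-Hadamard transform of every chunk.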
template<int kLogN, int kNChunks>
__device__ __forceinline__ void hadamard_mult_thread(float x[kNChunks][1 << kLogN]) {
    constexpr int N = 1 << kLogN;
    #pragma unroll
    for (int i = 0; i < kLogN; ++i) {
        const int stride = 1 << i;
        #pragma unroll
        for (int j = 0; j < N / 2; ++j) {
            const int lo = j & (stride - 1);
            const int idx = (j - lo) * 2 + lo;
            #pragma unroll
            for (int c = 0; c < kNChunks; ++c) {
                const float a = x[c][idx];
                const float b = x[c][idx + stride];
                x[c][idx] = a + b;
                x[c][idx + stride] = a - b;
            }
        }
    }
}
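
// Cross-lane continuation of the butterfly: lanes paired by lane_mask exchange values with
// __shfl_xor_sync, and `sign * x + x_val_other` yields a + b on the lane whose mask bit is
// clear and a - b on its partner, so each round extends the transform by one more power of
// two across the warp. kStepStart lets a caller skip rounds already handled elsewhere.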
template<int kLogWarpSize, int kStepStart, int kNChunks, int kNItems>
__device__ __forceinline__ void hadamard_mult_warp(float x[kNChunks][kNItems]) {
    constexpr int N = 1 << kLogWarpSize;
    int lane_id = threadIdx.x % N;
    #pragma unroll
    for (int step = kStepStart; step < kLogWarpSize; ++step) {
        const int lane_mask = 1 << step;
        const float sign = (lane_id & lane_mask) ? -1.f : 1.f;
        #pragma unroll
        for (int c = 0; c < kNChunks; ++c) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask);
                x[c][i] = sign * x[c][i] + x_val_other;
            }
        }
    }
}
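
// Sketch of how the two levels compose (assuming kNItems == 1 << kLogN and a full 32-lane
// warp): the thread-level butterfly covers the log2(N) dimensions held in registers and the
// warp-level butterfly the next 5, giving a transform of length N * 32 per warp:
//   hadamard_mult_thread<kLogN, kNChunks>(x);
//   hadamard_mult_warp<5, 0, kNChunks, 1 << kLogN>(x);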

////////////////////////////////////////////////////////////////////////////////////////////////////
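
// Vectorized load: each thread reads kNChunks vectors of kNElts elements (one vec_t per
// chunk, strided by blockDim.x), skips vectors whose first element lies at or beyond `dim`,
// and widens the values to float for the transform. Skipped chunks stay zero-initialized.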
template <int kNChunks, int kNElts, typename input_t>
inline __device__ void load_input(input_t *x, float x_vals[kNChunks][kNElts], int dim) {
    using vec_t = typename BytesToType<sizeof(input_t) * kNElts>::Type;
    input_t x_vals_load[kNChunks][kNElts] = {0};
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {
            reinterpret_cast<vec_t*>(x_vals_load)[c] = reinterpret_cast<const vec_t*>(x)[c * blockDim.x + threadIdx.x];
        }
    }
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { x_vals[c][i] = float(x_vals_load[c][i]); }
    }
}
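
// Vectorized store: the mirror of load_input. Values are scaled (e.g. by a normalization
// factor), narrowed back to output_t, and written as one vec_t per in-range chunk.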
template <int kNChunks, int kNElts, typename output_t>
inline __device__ void store_output(output_t *out, float out_vals[kNChunks][kNElts], int dim, float scale=1.f) {
    using vec_t = typename BytesToType<sizeof(output_t) * kNElts>::Type;
    output_t out_vals_store[kNChunks][kNElts];
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_vals_store[c][i] = out_vals[c][i] * scale; }
    }
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {
            reinterpret_cast<vec_t*>(out)[c * blockDim.x + threadIdx.x] = reinterpret_cast<const vec_t*>(out_vals_store)[c];
        }
    }
}
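
// Minimal usage sketch (assuming x and out point at this block's row of the input/output,
// and that sizeof(input_t) * kNElts is one of the BytesToType widths):
//   float vals[kNChunks][kNElts];
//   load_input<kNChunks, kNElts>(x, vals, dim);
//   // ... apply the Hadamard butterflies to vals ...
//   store_output<kNChunks, kNElts>(out, vals, dim, scale);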

////////////////////////////////////////////////////////////////////////////////////////////////////

// Pre == true means the exchange happens before hadamard_mult_warp, Pre == false means after.
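// exchange_smem_pre stages each thread's registers through shared memory, kChunksPerExchange
// chunks at a time, moving data between the warp-contiguous layout (indexed by warp_id /
// lane_id) and the transposed layout (indexed by row_t / col_t) so that warp-level
// butterflies can also be applied across warps.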
template <int kNChunks, int kChunksPerExchange, int kNElts, int kWarpSize, int kNWarps, bool Pre, typename vec_t>
inline __device__ void exchange_smem_pre(float x_vals[kNChunks][kNElts], vec_t *smem) {
    constexpr int kNThreads = kWarpSize * kNWarps;
    constexpr int kNExchangePerVec = kNElts / (sizeof(vec_t) / sizeof(float));
    const int warp_id = threadIdx.x / kWarpSize;
    const int lane_id = threadIdx.x % kWarpSize;
    const int row_t = threadIdx.x % kNWarps;
    const int col_t = threadIdx.x / kNWarps;
    // We use the XOR swizzle trick (new_col = col ^ row) to avoid / reduce smem bank conflicts.
    #pragma unroll
    for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) {
        __syncthreads();
        #pragma unroll
        for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
            #pragma unroll
            for (int r = 0; r < kNExchangePerVec; ++r) {
                smem[(c1 * kNExchangePerVec + r) * kNThreads + (Pre ? warp_id * kWarpSize + (lane_id ^ warp_id) : row_t * kWarpSize + (col_t ^ row_t))] = reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r];
            }
        }
        __syncthreads();
        #pragma unroll
        for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
            #pragma unroll
            for (int r = 0; r < kNExchangePerVec; ++r) {
                reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r] = smem[(c1 * kNExchangePerVec + r) * kNThreads + (Pre ? row_t * kWarpSize + (col_t ^ row_t) : warp_id * kWarpSize + (lane_id ^ warp_id))];
            }
        }
    }
}