fast_hadamard_transform_common.h

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#define FULL_MASK 0xffffffff

////////////////////////////////////////////////////////////////////////////////////////////////////

struct uint8 {
    uint4 u;
    uint4 v;
};

template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<32> {
    using Type = uint8;
    static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
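// Illustrative compile-time check (added for clarity, not part of the original
// header): a packet of 8 fp16 values occupies 16 bytes, so BytesToType maps it
// to a single uint4 for vectorized loads/stores.
static_assert(sizeof(BytesToType<sizeof(__half) * 8>::Type) == 16);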
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct SumOp {
    __device__ inline T operator()(T const& x, T const& y) { return x + y; }
};

template <int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template <>
struct Allreduce<2> {
    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
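// Usage sketch (illustrative, not part of the original header): an all-reduce
// sum over a full warp, assuming all 32 lanes are active.
//
//   SumOp<float> sum_op;
//   float warp_total = Allreduce<32>::run(thread_val, sum_op);
//
// The recursion applies __shfl_xor_sync at strides 16, 8, 4, 2, 1, so every
// lane ends up holding the same total.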
////////////////////////////////////////////////////////////////////////////////////////////////////

// https://stackoverflow.com/questions/35311711/whats-the-right-way-to-compute-integral-base-2-logarithms-at-compile-time
constexpr int cilog2(int val) { return val > 0 ? 1 + cilog2(val >> 1) : -1; }
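// Illustrative compile-time checks (added for clarity, not part of the
// original header): cilog2 returns floor(log2(val)) for val >= 1, and -1 for 0.
static_assert(cilog2(1) == 0);
static_assert(cilog2(8) == 3);
static_assert(cilog2(32) == 5);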
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kLogN, int kNChunks>
__device__ __forceinline__ void hadamard_mult_thread(float x[kNChunks][1 << kLogN]) {
    constexpr int N = 1 << kLogN;
    #pragma unroll
    for (int i = 0; i < kLogN; ++i) {
        const int stride = 1 << i;
        #pragma unroll
        for (int j = 0; j < N / 2; ++j) {
            const int lo = j & (stride - 1);
            const int idx = (j - lo) * 2 + lo;
            #pragma unroll
            for (int c = 0; c < kNChunks; ++c) {
                const float a = x[c][idx];
                const float b = x[c][idx + stride];
                x[c][idx] = a + b;
                x[c][idx + stride] = a - b;
            }
        }
    }
}
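// In-register butterfly: hadamard_mult_thread applies an (unnormalized)
// Hadamard transform of size N = 2^kLogN to each chunk held entirely by one
// thread. At the smallest size a pair (a, b) becomes (a + b, a - b), i.e.
// multiplication by [[1, 1], [1, -1]]; larger sizes repeat this butterfly at
// strides 1, 2, ..., N/2.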
template <int kLogWarpSize, int kStepStart, int kNChunks, int kNItems>
__device__ __forceinline__ void hadamard_mult_warp(float x[kNChunks][kNItems]) {
    constexpr int N = 1 << kLogWarpSize;
    int lane_id = threadIdx.x % N;
    #pragma unroll
    for (int step = kStepStart; step < kLogWarpSize; ++step) {
        const int lane_mask = 1 << step;
        const float sign = (lane_id & lane_mask) ? -1.f : 1.f;
        #pragma unroll
        for (int c = 0; c < kNChunks; ++c) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask);
                x[c][i] = sign * x[c][i] + x_val_other;
            }
        }
    }
}
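// Cross-lane butterfly: each step pairs lanes whose IDs differ in bit `step`
// and exchanges values with __shfl_xor_sync. The lane with that bit set
// negates its own contribution, so the pair computes (a + b, a - b) across
// lanes, extending the per-thread transform by another factor of
// 2^(kLogWarpSize - kStepStart) across the warp.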
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kNChunks, int kNElts, typename input_t>
inline __device__ void load_input(input_t* x, float x_vals[kNChunks][kNElts],
                                  int64_t dim) {
    using vec_t = typename BytesToType<sizeof(input_t) * kNElts>::Type;
    input_t x_vals_load[kNChunks][kNElts] = {0};
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {
            reinterpret_cast<vec_t*>(x_vals_load)[c] =
                reinterpret_cast<const vec_t*>(x)[c * blockDim.x + threadIdx.x];
        }
    }
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            x_vals[c][i] = float(x_vals_load[c][i]);
        }
    }
}
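// Vectorized load: each thread reads kNChunks aligned vectors of kNElts
// input_t elements (e.g. 16 bytes per vector when input_t is fp16/bf16 and
// kNElts is 8), guarded by the `dim` bound so out-of-range chunks stay zero,
// then widens everything to float for the transform.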
template <int kNChunks, int kNElts, typename output_t>
inline __device__ void store_output(output_t* out,
                                    float out_vals[kNChunks][kNElts],
                                    int64_t dim, float scale = 1.f) {
    using vec_t = typename BytesToType<sizeof(output_t) * kNElts>::Type;
    output_t out_vals_store[kNChunks][kNElts];
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals_store[c][i] = out_vals[c][i] * scale;
        }
    }
    #pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
        if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {
            reinterpret_cast<vec_t*>(out)[c * blockDim.x + threadIdx.x] =
                reinterpret_cast<const vec_t*>(out_vals_store)[c];
        }
    }
}
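// Mirror of load_input: converts back to output_t with an optional scale
// applied, then performs the same guarded vectorized stores. A typical round
// trip in a kernel (illustrative sketch, caller-chosen scale) is
//   load_input(...) -> hadamard_mult_thread / hadamard_mult_warp ->
//   store_output(..., scale),
// where the caller may pick e.g. scale = 1/sqrt(dim) to make the transform
// orthonormal.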
////////////////////////////////////////////////////////////////////////////////////////////////////

// Pre=true means the exchange before the hadamard_mult_warp, Pre=false means after.
template <int kNChunks, int kChunksPerExchange, int kNElts, int kWarpSize, int kNWarps,
          bool Pre, typename vec_t>
inline __device__ void exchange_smem_pre(float x_vals[kNChunks][kNElts], vec_t* smem) {
    constexpr int kNThreads = kWarpSize * kNWarps;
    constexpr int kNExchangePerVec = kNElts / (sizeof(vec_t) / sizeof(float));
    const int warp_id = threadIdx.x / kWarpSize;
    const int lane_id = threadIdx.x % kWarpSize;
    const int row_t = threadIdx.x % kNWarps;
    const int col_t = threadIdx.x / kNWarps;
    // We use the XOR swizzle trick (new_col = col ^ row) to avoid / reduce smem
    // bank conflicts.
    #pragma unroll
    for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) {
        __syncthreads();
        #pragma unroll
        for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
            #pragma unroll
            for (int r = 0; r < kNExchangePerVec; ++r) {
                smem[(c1 * kNExchangePerVec + r) * kNThreads +
                     (Pre ? warp_id * kWarpSize + (lane_id ^ warp_id)
                          : row_t * kWarpSize + (col_t ^ row_t))] =
                    reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r];
            }
        }
        __syncthreads();
        #pragma unroll
        for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
            #pragma unroll
            for (int r = 0; r < kNExchangePerVec; ++r) {
                reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r] =
                    smem[(c1 * kNExchangePerVec + r) * kNThreads +
                         (Pre ? row_t * kWarpSize + (col_t ^ row_t)
                              : warp_id * kWarpSize + (lane_id ^ warp_id))];
            }
        }
    }
}
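// What the exchange does: each thread's registers are staged through shared
// memory and read back in a transposed layout, mapping between the (warp_id,
// lane_id) indexing and the (row_t, col_t) indexing, so that a following
// hadamard_mult_warp call combines elements that originally lived in
// different warps; Pre=false applies the inverse permutation afterwards.
// XOR-ing the column with the row spreads a warp's accesses across shared
// memory banks, and kNExchangePerVec counts how many vec_t-sized pieces each
// thread's float registers split into per chunk.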