fast_hadamard_transform_cuda.cu 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. // #pragma once
  5. #include <c10/util/BFloat16.h>
  6. #include <c10/util/Half.h>
  7. #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
  8. #include "fast_hadamard_transform.h"
  9. #include "fast_hadamard_transform_common.h"
  10. #include "fast_hadamard_transform_special.h"
  11. #include "static_switch.h"
// Compile-time configuration for the power-of-two fast Hadamard transform kernel.
// Each of the params.batch rows has N = 2^kLogN elements, processed by kNThreads
// threads; each thread keeps kNChunks chunks of kNElts values in registers.
template<int kNThreads_, int kLogN_, typename input_t_>
struct fast_hadamard_transform_kernel_traits {
    using input_t = input_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kLogN = kLogN_;
    static constexpr int N = 1 << kLogN;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // 16-byte vectorized loads: 4 floats, or 8 half/bfloat16 values.
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    static constexpr int kNChunks = N / (kNElts * kNThreads);
    // We don't want to use more than 32 KB of shared memory.
    static constexpr int kSmemExchangeSize = std::min(N * 4, 32 * 1024);
    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
    // The N * 4 bytes of float data must split into a whole number of exchange rounds.
    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
    static constexpr int kSmemSize = kSmemExchangeSize;
};
// Compile-time configuration for dimensions of the form 12 * 2^kLogN.
// The extra factor of 12 is handled by a per-thread 12x12 Hadamard-like multiply
// across chunks (see hadamard_mult_thread_chunk_12), so kNChunks is pinned to 12.
template<int kNThreads_, int kLogN_, typename input_t_>
struct fast_hadamard_transform_12N_kernel_traits {
    using input_t = input_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kLogN = kLogN_;
    static constexpr int N = (1 << kLogN) * 12;
    static_assert(N <= 12 * 1024, "fast_hadamard_transform_12 only supports dim <= 12288");
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = 4;
    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    static constexpr int kNChunks = N / (kNElts * kNThreads);
    static_assert(kNChunks == 12);
    // We don't want to use more than 24 KB of shared memory.
    static constexpr int kSmemExchangeSize = std::min(N * 4, 24 * 1024);
    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
    static constexpr int kSmemSize = kSmemExchangeSize;
};
// Compile-time configuration for dimensions of the form 20 * 2^kLogN.
// The extra factor of 20 is handled by a per-thread 20x20 Hadamard-like multiply
// across chunks (see hadamard_mult_thread_chunk_20), so kNChunks is pinned to 20.
template<int kNThreads_, int kLogN_, typename input_t_>
struct fast_hadamard_transform_20N_kernel_traits {
    using input_t = input_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kLogN = kLogN_;
    static constexpr int N = (1 << kLogN) * 20;
    static_assert(N <= 20 * 1024, "fast_hadamard_transform_20 only supports dim <= 20480");
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = 4;
    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    static constexpr int kNChunks = N / (kNElts * kNThreads);
    static_assert(kNChunks == 20);
    // We don't want to use more than 40 KB of shared memory.
    static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024);
    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
    static constexpr int kSmemSize = kSmemExchangeSize;
};
// Compile-time configuration for dimensions of the form 28 * 2^kLogN.
// The extra factor of 28 is handled by a per-thread 28x28 Hadamard-like multiply
// across chunks (see hadamard_mult_thread_chunk_28), so kNChunks is pinned to 28.
template<int kNThreads_, int kLogN_, typename input_t_>
struct fast_hadamard_transform_28N_kernel_traits {
    using input_t = input_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kLogN = kLogN_;
    static constexpr int N = (1 << kLogN) * 28;
    static_assert(N <= 28 * 1024, "fast_hadamard_transform_28 only supports dim <= 28672");
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = 4;
    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    static constexpr int kNChunks = N / (kNElts * kNThreads);
    static_assert(kNChunks == 28);
    // We don't want to use more than 28 KB of shared memory.
    static constexpr int kSmemExchangeSize = std::min(N * 4, 28 * 1024);
    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
    static constexpr int kSmemSize = kSmemExchangeSize;
};
  98. template <int kNChunks>
  99. __device__ __forceinline__ void hadamard_mult_thread_chunk_12(float x[kNChunks][12]) {
  100. #pragma unroll
  101. for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_12(x[c]); }
  102. }
  103. template <int kNChunks>
  104. __device__ __forceinline__ void hadamard_mult_thread_chunk_20(float x[kNChunks][20]) {
  105. #pragma unroll
  106. for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_20(x[c]); }
  107. }
  108. template <int kNChunks>
  109. __device__ __forceinline__ void hadamard_mult_thread_chunk_28(float x[kNChunks][28]) {
  110. #pragma unroll
  111. for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_28(x[c]); }
  112. }
// Fast Hadamard transform over one batch row per block.
// Grid: (params.batch); block: Ktraits::kNThreads; dynamic shared memory:
// Ktraits::kSmemSize bytes (used as an exchange buffer between warps).
// The transform is applied hierarchically:
//   1. within each thread's kNElts-wide chunks,
//   2. across lanes of a warp (register shuffles inside hadamard_mult_warp —
//      presumably; the helper is defined elsewhere),
//   3. across warps, staging data through shared memory when kNWarps > 1,
//   4. across chunks, via a transpose in registers when kNChunks > 1.
// Finally the result is scaled by params.scale and stored.
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void fast_hadamard_transform_kernel(HadamardParamsBase params) {
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNExchangePerVec = Ktraits::kNExchangePerVec;
    constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
    constexpr int kNChunks = Ktraits::kNChunks;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    constexpr int kLogNElts = cilog2(Ktraits::kNElts);
    static_assert(1 << kLogNElts == kNElts, "kNElts must be a power of 2");
    // Blocks smaller than a hardware warp still work: treat the whole block as one "warp".
    constexpr int kWarpSize = std::min(kNThreads, 32);
    constexpr int kLogWarpSize = cilog2(kWarpSize);
    static_assert(1 << kLogWarpSize == kWarpSize, "Warp size must be a power of 2");
    constexpr int kNWarps = kNThreads / kWarpSize;
    constexpr int kLogNWarps = cilog2(kNWarps);
    static_assert(1 << kLogNWarps == kNWarps, "kNWarps must be a power of 2");
    constexpr int kLoadsPerExchange = Ktraits::kSmemExchangeSize / (sizeof(vec_t) * kNThreads);
    static_assert(kLoadsPerExchange * sizeof(vec_t) * kNThreads == Ktraits::kSmemExchangeSize, "kSmemExchangeSize should be a power of 2");
    static_assert(kNExchangeRounds * kLoadsPerExchange * sizeof(vec_t) == kNChunks * kNElts * sizeof(float));
    // How many register chunks fit into one shared-memory exchange pass.
    constexpr int kChunksPerExchange = Ktraits::kSmemExchangeSize / (sizeof(vec_t) * kNExchangePerVec * kNThreads);
    static_assert(kChunksPerExchange * sizeof(vec_t) * kNExchangePerVec * kNThreads == Ktraits::kSmemExchangeSize);
    constexpr int kNExchanges = kNChunks / kChunksPerExchange;
    static_assert(kNExchanges * kChunksPerExchange == kNChunks);
    // Shared memory.
    extern __shared__ char smem_[];
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_);
    // One block handles one batch element.
    const int batch_id = blockIdx.x;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride;
    // All arithmetic is done in float regardless of input_t.
    float x_vals[kNChunks][kNElts];
    load_input<kNChunks, kNElts, input_t>(x, x_vals, params.dim);
    // Stage 1: transform within each thread's kNElts-element chunk.
    hadamard_mult_thread<kLogNElts, kNChunks>(x_vals);
    // Stage 2: transform across the kWarpSize lanes of a warp.
    hadamard_mult_warp<kLogWarpSize, 0, kNChunks, kNElts>(x_vals);
    if constexpr (kNWarps > 1) {
        // Stage 3: reshuffle through shared memory so that a warp-level
        // transform combines values across warps, then shuffle back.
        exchange_smem_pre<kNChunks, kChunksPerExchange, kNElts, kWarpSize, kNWarps, true, vec_t>(x_vals, smem_exchange);
        hadamard_mult_warp<kLogNWarps, 0, kNChunks, kNElts>(x_vals);
        exchange_smem_pre<kNChunks, kChunksPerExchange, kNElts, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
    }
    if constexpr (kNChunks > 1) {
        // Stage 4: transform across chunks. Transpose in registers so the
        // chunk dimension is contiguous per call.
        float x_vals_transposed[kNElts][kNChunks];
        #pragma unroll
        for (int c = 0; c < kNChunks; ++c) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
        }
        // 12 / 20 / 28 chunks use the special non-power-of-two multiplies;
        // any other chunk count must be a power of two.
        if constexpr (kNChunks == 12) {
            hadamard_mult_thread_chunk_12<kNElts>(x_vals_transposed);
        } else if constexpr (kNChunks == 20) {
            hadamard_mult_thread_chunk_20<kNElts>(x_vals_transposed);
        } else if constexpr (kNChunks == 28) {
            hadamard_mult_thread_chunk_28<kNElts>(x_vals_transposed);
        } else {
            constexpr int kLogNChunks = cilog2(kNChunks);
            static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
            hadamard_mult_thread<kLogNChunks, kNElts>(x_vals_transposed);
        }
        // Transpose back to the layout store_output expects.
        #pragma unroll
        for (int c = 0; c < kNChunks; ++c) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
        }
    }
    store_output<kNChunks, kNElts, input_t>(out, x_vals, params.dim, params.scale);
}
  179. template<int kNThreads, int kLogN, typename input_t>
  180. void fast_hadamard_transform_launch(HadamardParamsBase &params, cudaStream_t stream) {
  181. using Ktraits = fast_hadamard_transform_kernel_traits<kNThreads, kLogN, input_t>;
  182. constexpr int kSmemSize = Ktraits::kSmemSize;
  183. dim3 grid(params.batch);
  184. auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
  185. if (kSmemSize >= 48 * 1024) {
  186. C10_CUDA_CHECK(cudaFuncSetAttribute(
  187. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
  188. }
  189. kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
  190. C10_CUDA_KERNEL_LAUNCH_CHECK();
  191. }
  192. template<typename input_t>
  193. void fast_hadamard_transform_cuda(HadamardParamsBase &params, cudaStream_t stream) {
  194. if (params.log_N == 3) {
  195. fast_hadamard_transform_launch<1, 3, input_t>(params, stream);
  196. } else if (params.log_N == 4) {
  197. fast_hadamard_transform_launch<2, 4, input_t>(params, stream);
  198. } else if (params.log_N == 5) {
  199. fast_hadamard_transform_launch<4, 5, input_t>(params, stream);
  200. } else if (params.log_N == 6) {
  201. fast_hadamard_transform_launch<8, 6, input_t>(params, stream);
  202. } else if (params.log_N == 7) {
  203. fast_hadamard_transform_launch<16, 7, input_t>(params, stream);
  204. } else if (params.log_N == 8) {
  205. fast_hadamard_transform_launch<32, 8, input_t>(params, stream);
  206. } else if (params.log_N == 9) {
  207. fast_hadamard_transform_launch<32, 9, input_t>(params, stream);
  208. } else if (params.log_N == 10) {
  209. fast_hadamard_transform_launch<128, 10, input_t>(params, stream);
  210. } else if (params.log_N == 11) {
  211. fast_hadamard_transform_launch<256, 11, input_t>(params, stream);
  212. } else if (params.log_N == 12) {
  213. fast_hadamard_transform_launch<256, 12, input_t>(params, stream);
  214. } else if (params.log_N == 13) {
  215. fast_hadamard_transform_launch<256, 13, input_t>(params, stream);
  216. } else if (params.log_N == 14) {
  217. fast_hadamard_transform_launch<256, 14, input_t>(params, stream);
  218. } else if (params.log_N == 15) {
  219. fast_hadamard_transform_launch<256, 15, input_t>(params, stream);
  220. }
  221. }
  222. template<int kNThreads, int kLogN, typename input_t>
  223. void fast_hadamard_transform_12N_launch(HadamardParamsBase &params, cudaStream_t stream) {
  224. using Ktraits = fast_hadamard_transform_20N_kernel_traits<kNThreads, kLogN, input_t>;
  225. constexpr int kSmemSize = Ktraits::kSmemSize;
  226. dim3 grid(params.batch);
  227. auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
  228. if (kSmemSize >= 48 * 1024) {
  229. C10_CUDA_CHECK(cudaFuncSetAttribute(
  230. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
  231. }
  232. kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
  233. C10_CUDA_KERNEL_LAUNCH_CHECK();
  234. }
  235. template<typename input_t>
  236. void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t stream) {
  237. if (params.log_N == 2) {
  238. fast_hadamard_transform_12N_launch<1, 2, input_t>(params, stream);
  239. } else if (params.log_N == 2) {
  240. fast_hadamard_transform_12N_launch<2, 3, input_t>(params, stream);
  241. } else if (params.log_N == 4) {
  242. fast_hadamard_transform_12N_launch<4, 4, input_t>(params, stream);
  243. } else if (params.log_N == 5) {
  244. fast_hadamard_transform_12N_launch<8, 5, input_t>(params, stream);
  245. } else if (params.log_N == 6) {
  246. fast_hadamard_transform_12N_launch<16, 6, input_t>(params, stream);
  247. } else if (params.log_N == 7) {
  248. fast_hadamard_transform_12N_launch<32, 7, input_t>(params, stream);
  249. } else if (params.log_N == 8) {
  250. fast_hadamard_transform_12N_launch<64, 8, input_t>(params, stream);
  251. } else if (params.log_N == 9) {
  252. fast_hadamard_transform_12N_launch<128, 9, input_t>(params, stream);
  253. } else if (params.log_N == 10) {
  254. fast_hadamard_transform_12N_launch<256, 10, input_t>(params, stream);
  255. }
  256. }
  257. template<int kNThreads, int kLogN, typename input_t>
  258. void fast_hadamard_transform_20N_launch(HadamardParamsBase &params, cudaStream_t stream) {
  259. using Ktraits = fast_hadamard_transform_20N_kernel_traits<kNThreads, kLogN, input_t>;
  260. constexpr int kSmemSize = Ktraits::kSmemSize;
  261. dim3 grid(params.batch);
  262. auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
  263. if (kSmemSize >= 48 * 1024) {
  264. C10_CUDA_CHECK(cudaFuncSetAttribute(
  265. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
  266. }
  267. kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
  268. C10_CUDA_KERNEL_LAUNCH_CHECK();
  269. }
  270. template<typename input_t>
  271. void fast_hadamard_transform_20N_cuda(HadamardParamsBase &params, cudaStream_t stream) {
  272. if (params.log_N == 2) {
  273. fast_hadamard_transform_20N_launch<1, 2, input_t>(params, stream);
  274. } else if (params.log_N == 2) {
  275. fast_hadamard_transform_20N_launch<2, 3, input_t>(params, stream);
  276. } else if (params.log_N == 4) {
  277. fast_hadamard_transform_20N_launch<4, 4, input_t>(params, stream);
  278. } else if (params.log_N == 5) {
  279. fast_hadamard_transform_20N_launch<8, 5, input_t>(params, stream);
  280. } else if (params.log_N == 6) {
  281. fast_hadamard_transform_20N_launch<16, 6, input_t>(params, stream);
  282. } else if (params.log_N == 7) {
  283. fast_hadamard_transform_20N_launch<32, 7, input_t>(params, stream);
  284. } else if (params.log_N == 8) {
  285. fast_hadamard_transform_20N_launch<64, 8, input_t>(params, stream);
  286. } else if (params.log_N == 9) {
  287. fast_hadamard_transform_20N_launch<128, 9, input_t>(params, stream);
  288. } else if (params.log_N == 10) {
  289. fast_hadamard_transform_20N_launch<256, 10, input_t>(params, stream);
  290. }
  291. }
  292. template<int kNThreads, int kLogN, typename input_t>
  293. void fast_hadamard_transform_28N_launch(HadamardParamsBase &params, cudaStream_t stream) {
  294. using Ktraits = fast_hadamard_transform_28N_kernel_traits<kNThreads, kLogN, input_t>;
  295. constexpr int kSmemSize = Ktraits::kSmemSize;
  296. dim3 grid(params.batch);
  297. auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
  298. if (kSmemSize >= 48 * 1024) {
  299. C10_CUDA_CHECK(cudaFuncSetAttribute(
  300. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
  301. }
  302. kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
  303. C10_CUDA_KERNEL_LAUNCH_CHECK();
  304. }
  305. template<typename input_t>
  306. void fast_hadamard_transform_28N_cuda(HadamardParamsBase &params, cudaStream_t stream) {
  307. if (params.log_N == 2) {
  308. fast_hadamard_transform_28N_launch<1, 2, input_t>(params, stream);
  309. } else if (params.log_N == 2) {
  310. fast_hadamard_transform_28N_launch<2, 3, input_t>(params, stream);
  311. } else if (params.log_N == 4) {
  312. fast_hadamard_transform_28N_launch<4, 4, input_t>(params, stream);
  313. } else if (params.log_N == 5) {
  314. fast_hadamard_transform_28N_launch<8, 5, input_t>(params, stream);
  315. } else if (params.log_N == 6) {
  316. fast_hadamard_transform_28N_launch<16, 6, input_t>(params, stream);
  317. } else if (params.log_N == 7) {
  318. fast_hadamard_transform_28N_launch<32, 7, input_t>(params, stream);
  319. } else if (params.log_N == 8) {
  320. fast_hadamard_transform_28N_launch<64, 8, input_t>(params, stream);
  321. } else if (params.log_N == 9) {
  322. fast_hadamard_transform_28N_launch<128, 9, input_t>(params, stream);
  323. } else if (params.log_N == 10) {
  324. fast_hadamard_transform_28N_launch<256, 10, input_t>(params, stream);
  325. }
  326. }
// Explicit instantiations for the three supported element types (fp32, fp16,
// bf16), so these template definitions can stay in this .cu file while other
// translation units link against them.
template void fast_hadamard_transform_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_12N_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_12N_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_12N_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_20N_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_20N_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_20N_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_28N_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_28N_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
template void fast_hadamard_transform_28N_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);