ln_parallel_residual_fwd_kernels.cuh 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. #pragma once
  2. #ifdef OLD_GENERATOR_PATH
  3. #include <ATen/CUDAGeneratorImpl.h>
  4. #else
  5. #include <ATen/cuda/CUDAGeneratorImpl.h>
  6. #endif
  7. #include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
  8. #include <curand_kernel.h>
  9. #include "ln.h"
  10. #include "ln_utils.cuh"
  11. #include "ln_kernel_traits.h"
  12. #include "static_switch.h"
  namespace layer_norm {

  // Forward kernel for layer norm with a "parallel residual" input.
  //
  // For every row it processes, the kernel
  //   1. builds x = dropout(x0) [+ dropout(x1)] [+ residual]; x1 and residual are
  //      only used when params.x1 / params.residual are non-null, and dropout is
  //      only applied when Is_dropout is true,
  //   2. stores x to params.x when save_x holds, and the keep-masks to
  //      params.dmask / params.dmask1 when Is_dropout is true,
  //   3. reduces the row statistics through Ktraits::Stats and writes the mean to
  //      params.mu and the reciprocal standard deviation to params.rs,
  //   4. writes z = gamma * y + beta to params.z and, unless Tied_norm is true,
  //      z1 = gamma1 * y + beta1 to params.z1, where y is the normalised row.
  //      A null params.beta / params.beta1 is treated as an all-zero bias.
  // When params.is_rms_norm is set the mean is not subtracted and mu * mu is added
  // under the square root instead.
  //
  // Template parameters:
  //   Ktraits      - kernel traits providing the tiling constants, element types,
  //                  vector types and the Stats reducer.
  //   Is_dropout   - compile-time switch for the dropout path.
  //   Tied_norm    - when true only the gamma/beta pair is applied (no z1 output).
  //   Is_even_cols - when true every one of the LDGS vector loads is taken;
  //                  otherwise only the first num_valid_ldgs loads of each thread
  //                  are touched.
  //
  // Launch layout as used by the code below: blockIdx.x is split into
  // (bidm, bidn) by CTAS_PER_ROW, each CTA starts at row bidm * ROWS_PER_CTA + warp_m
  // and advances by params.ctas_per_col * ROWS_PER_CTA rows per iteration. The
  // dynamic shared memory buffer is handed to Ktraits::Stats.
  //
  // Precondition (already implied by the vectorised addressing): params.cols is a
  // multiple of Ktraits::ELTS_PER_LDG.
  template<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>
  __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
  void ln_parallel_residual_fwd_kernel(FwdParams params) {

      enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
      enum { WARPS_N = Ktraits::WARPS_N };
      enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
      enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
      enum { LDGS = Ktraits::LDGS };
      enum { NUM_ELTS = Ktraits::NUM_ELTS };
      enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

      static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");

      using input_t = typename Ktraits::input_t;
      using residual_t = typename Ktraits::residual_t;
      using output_t = typename Ktraits::output_t;
      using index_t = typename Ktraits::index_t;
      using compute_t = typename Ktraits::compute_t;
      using mask_t = typename Ktraits::mask_t;
      using Ivec = typename Ktraits::Ivec;
      using Rvec = typename Ktraits::Rvec;
      using Ovec = typename Ktraits::Ovec;
      using Wvec = typename Ktraits::Wvec;
      using Mvec = typename Ktraits::Mvec;

      using Stats = typename Ktraits::Stats;
      using stats_t = typename Stats::stats_t;

      const bool has_residual = params.residual != nullptr;
      const bool has_x1 = params.x1 != nullptr;
      // x only has to be written out when it differs from x0 in value or in type.
      const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same<input_t, residual_t>::value);

      extern __shared__ char smem_[];

      const index_t tidx = threadIdx.x;
      const index_t bidn = blockIdx.x % CTAS_PER_ROW;
      const index_t bidm = blockIdx.x / CTAS_PER_ROW;
      const index_t lane = tidx % THREADS_PER_WARP;
      const index_t warp = tidx / THREADS_PER_WARP;
      const index_t warp_m = warp / WARPS_N;
      const index_t warp_n = warp % WARPS_N;

      // First row handled by this warp and the vector column owned by this thread.
      const index_t r = bidm * ROWS_PER_CTA + warp_m;
      const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

      Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

      compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
      compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

      // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
      // Every thread owns its own Philox state, using its global thread index as
      // the subsequence argument of curand_init.
      curandStatePhilox4_32_10_t state;
      if (Is_dropout) {
          auto seeds = at::cuda::philox::unpack(params.philox_args);
          const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
          curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
      }

      // Number of this thread's vector loads that fall inside the row; only
      // consulted when Is_even_cols is false.
      const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;

      // Row-independent column bookkeeping, computed once instead of per row.
      const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
      const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
      const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
      // Returns how many elements of a row are valid for the given warp column.
      auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
          // Need to convert to int, otherwise the subtraction will wrap around.
          const index_t valid_partial_vecs_in_warp =
              std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                       int(THREADS_PER_WARP));
          return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
      };

      // The affine parameters do not depend on the row, so they are loaded once.
      Wvec gamma0[LDGS];
      Wvec beta0[LDGS];
      Wvec gamma1[LDGS];
      Wvec beta1[LDGS];
      index_t w_idx = c;
      #pragma unroll
      for( int it = 0; it < LDGS; it++ ) {
          if (Is_even_cols || (it < num_valid_ldgs)) {
              gamma0[it].load_from(params.gamma, w_idx);
              if (params.beta != nullptr) {
                  beta0[it].load_from(params.beta, w_idx);
              } else {
                  beta0[it].zero_();
              }
              if (!Tied_norm) {
                  gamma1[it].load_from(params.gamma1, w_idx);
                  if (params.beta1 != nullptr) {
                      beta1[it].load_from(params.beta1, w_idx);
                  } else {
                      beta1[it].zero_();
                  }
              }
              w_idx += VEC_COLS_PER_LDG;
          }
      }

      for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
          // Vector index of this thread's first load in the row. Multiplying by the
          // per-row vector count (rather than forming row * cols and dividing
          // afterwards) keeps the intermediate ELTS_PER_LDG times smaller, so it
          // cannot overflow before the final index itself does.
          const index_t row_start = index_t(row) * num_vecs + c;

          index_t idx = row_start;
          compute_t xf[LDGS * NUM_ELTS];
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              if (Is_even_cols || (it < num_valid_ldgs)) {
                  Ivec x0;
                  Ivec x1;
                  Rvec residual;
                  Rvec x;
                  Mvec dmask0;
                  Mvec dmask1;
                  x0.load_from(params.x0, idx);
                  if (has_x1) { x1.load_from(params.x1, idx); }
                  if (has_residual) { residual.load_from(params.residual, idx); }
                  #pragma unroll
                  for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                      // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                      // the more efficient curand_uniform4.
                      compute_t x_ij;
                      mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                      if (Is_dropout) { dmask0.data.elt[jt] = keep0; }
                      compute_t x0_ij = compute_t(x0.data.elt[jt]);
                      x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                      if (has_x1) {
                          // x1 draws its own, independent keep decision.
                          mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                          if (Is_dropout) { dmask1.data.elt[jt] = keep1; }
                          compute_t x1_ij = compute_t(x1.data.elt[jt]);
                          x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f;
                          x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij;
                      } else {
                          x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                      }
                      if (save_x) { x.data.elt[jt] = x_ij; }
                      xf[it * NUM_ELTS + jt] = x_ij;
                  }
                  if (save_x) { x.store_to(params.x, idx); }
                  if (Is_dropout) {
                      dmask0.store_to(params.dmask, idx);
                      if (has_x1) { dmask1.store_to(params.dmask1, idx); }
                  }
                  idx += VEC_COLS_PER_LDG;
              }
          }

          stats_t s = stats.template compute<Is_even_cols>(
              xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
          );

          compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
          compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

          // A single thread per row publishes the statistics.
          if( bidn == 0 && warp_n == 0 && lane == 0 ) {
              mu_ptr[row] = mu;
          }

          // For RMS norm the mean is not removed, so mu * mu is added back here.
          compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));

          if( bidn == 0 && warp_n == 0 && lane == 0 ) {
              rs_ptr[row] = rs;
          }

          idx = row_start;
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              if (Is_even_cols || (it < num_valid_ldgs)) {
                  Ovec z0;
                  Ovec z1;
                  #pragma unroll
                  for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                      compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
                      compute_t g0_ij = gamma0[it].data.elt[jt];
                      compute_t b0_ij = beta0[it].data.elt[jt];
                      z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij);
                      if (!Tied_norm) {
                          compute_t g1_ij = gamma1[it].data.elt[jt];
                          compute_t b1_ij = beta1[it].data.elt[jt];
                          z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij);
                      }
                  }
                  z0.store_to(params.z, idx);
                  if (!Tied_norm) { z1.store_to(params.z1, idx); }
                  idx += VEC_COLS_PER_LDG;
              }
          }
      }
  }

  }  // namespace layer_norm
  182. using namespace layer_norm;
  // Host-side entry point for ln_parallel_residual_fwd_kernel.
  //
  // The template arguments are forwarded unchanged to Kernel_traits, which fixes
  // the tiling for one (types, HIDDEN_SIZE) combination. Three runtime properties
  // are turned into compile-time kernel switches:
  //   - dropout is enabled when params.dropout_keep_p < 1,
  //   - the norm is "tied" (single gamma/beta pair) when params.gamma1 is null,
  //   - columns are "even" when params.cols equals HIDDEN_SIZE.
  //
  // Parameters:
  //   launch_params    - carries the kernel parameters, the stream and the device
  //                      properties; in configure mode several of its fields are
  //                      written (see below).
  //   configure_params - when true nothing is launched. Instead ctas_per_col,
  //                      elts_per_thread, barrier_size and workspace_bytes are
  //                      filled in from an occupancy query and the function
  //                      returns. When false the kernel is launched on
  //                      launch_params.stream using the previously configured
  //                      ctas_per_col.
  //
  // Errors from the CUDA runtime, including kernel launch failures, are reported
  // through CHECK_CUDA.
  template<
      typename weight_t,
      typename input_t,
      typename residual_t,
      typename output_t,
      typename compute_t,
      typename index_t,
      int HIDDEN_SIZE,
      int CTAS_PER_ROW,
      int WARPS_M,
      int WARPS_N,
      int BYTES_PER_LDG
  >
  void launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const bool configure_params){

      using Kernel_traits = Kernel_traits<weight_t,
                                          input_t,
                                          residual_t,
                                          output_t,
                                          compute_t,
                                          index_t,
                                          HIDDEN_SIZE,
                                          CTAS_PER_ROW,
                                          WARPS_M,
                                          WARPS_N,
                                          BYTES_PER_LDG
                                          >;
      bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
      bool tied_norm = launch_params.params.gamma1 == nullptr;
      BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
          BOOL_SWITCH(tied_norm, TiedNormConst, [&] {
              BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                  auto kernel = &ln_parallel_residual_fwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;
                  if( configure_params ) {
                      // Size the grid so that every SM is filled with as many CTAs
                      // as the occupancy calculator reports for this kernel.
                      int ctas_per_sm;
                      CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                          &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
                      launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                      const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
                      // ceil(rows / rows_per_loop) row iterations, each touching
                      // LDGS * NUM_ELTS elements per thread.
                      launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
                      launch_params.barrier_size = 0;
                      launch_params.workspace_bytes = 0;
                      if(Kernel_traits::CTAS_PER_ROW > 1) {
                          launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                          launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                        * Kernel_traits::WARPS_M
                                                        * Kernel_traits::CTAS_PER_ROW
                                                        * sizeof(typename Kernel_traits::Stats::stats_t)
                                                        * 2;
                      }
                      return;
                  }

                  // Dynamic shared memory of 48 KB or more has to be opted into explicitly.
                  if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
                      CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
                  }
                  auto stream = launch_params.stream;
                  auto ctas_per_col = launch_params.params.ctas_per_col;

                  if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                      kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
                      // A <<<>>> launch returns nothing; configuration errors are
                      // only visible through the last-error query.
                      CHECK_CUDA(cudaGetLastError());
                  } else {
                      dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                      dim3 block(Kernel_traits::THREADS_PER_CTA);
                      // Argument array with a single entry: the address of the params struct.
                      void *params_ = (void *)&launch_params.params;
                      CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream));
                  }
              });
          });
      });
  }