ln_fwd_kernels.cuh 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. #pragma once
  2. #ifdef OLD_GENERATOR_PATH
  3. #include <ATen/CUDAGeneratorImpl.h>
  4. #else
  5. #include <ATen/cuda/CUDAGeneratorImpl.h>
  6. #endif
  7. #include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
  8. #include <curand_kernel.h>
  9. #include "ln.h"
  10. #include "ln_utils.cuh"
  11. #include "ln_kernel_traits.h"
  12. #include "static_switch.h"
  namespace layer_norm {

  // Forward pass of the fused "scale / dropout / residual-add + normalization" kernel.
  //
  // For every row visited by a thread the kernel
  //   1. builds x = [dropout](x0 * rowscale) [* colscale] [+ residual] in compute_t,
  //      optionally storing x to params.x and the dropout keep-mask to params.dmask;
  //   2. reduces the row through Ktraits::Stats, then writes the first statistic to params.mu
  //      and rs = rsqrtf(m2 * inverse_cols + epsilon [+ mu * mu when is_rms_norm]) to params.rs;
  //   3. writes z = gamma * (rs * (x - mu)) + beta to params.z (mu is not subtracted when
  //      params.is_rms_norm is set; beta is treated as zero when params.beta is null).
  //
  // Template flags:
  //   Is_dropout   - draw a Philox keep-mask and rescale kept values by params.dropout_scale.
  //   Has_colscale - multiply by params.colscale.
  //   Has_subset   - x0_subset[row] / z_subset[row] hold 1-based row indices into x0 / z;
  //                  a value <= 0 means "no x0 input" / "no z output" for that row.
  //   Is_even_cols - all LDGS vector loads are in range, so the per-load bounds check is skipped.
  //
  // The kernel uses dynamic shared memory (smem_), which is handed to the Stats object.
  // Only CTAS_PER_ROW == 1 is supported; this is enforced by a static_assert.
  template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
  __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
  void ln_fwd_kernel(FwdParams params) {

      // ---- Compile-time tile geometry re-exported from the traits class. ----
      enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
      enum { WARPS_N = Ktraits::WARPS_N };
      enum { WARPS_M = Ktraits::WARPS_M };
      enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
      enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
      enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
      enum { LDGS = Ktraits::LDGS };
      enum { NUM_ELTS = Ktraits::NUM_ELTS };
      enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

      static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");

      // ---- Element and vector types. ----
      using input_t = typename Ktraits::input_t;
      using residual_t = typename Ktraits::residual_t;
      using output_t = typename Ktraits::output_t;
      using index_t = typename Ktraits::index_t;
      using compute_t = typename Ktraits::compute_t;
      using mask_t = typename Ktraits::mask_t;

      using Ivec = typename Ktraits::Ivec;
      using Rvec = typename Ktraits::Rvec;
      using Ovec = typename Ktraits::Ovec;
      using Wvec = typename Ktraits::Wvec;
      using Cvec = typename Ktraits::Cvec;
      using Mvec = typename Ktraits::Mvec;

      using Stats = typename Ktraits::Stats;
      using stats_t = typename Stats::stats_t;

      const bool with_residual = params.residual != nullptr;
      // x is written to params.x if any of these conditions holds.
      const bool store_x = with_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr)
                           || Has_subset || !(std::is_same<input_t, residual_t>::value);

      extern __shared__ char smem_[];

      // ---- Thread coordinates inside the CTA / grid. ----
      const index_t tid = threadIdx.x;
      const index_t cta_n = blockIdx.x % CTAS_PER_ROW;
      const index_t cta_m = blockIdx.x / CTAS_PER_ROW;
      const index_t lane_id = tid % THREADS_PER_WARP;
      const index_t warp_id = tid / THREADS_PER_WARP;
      const index_t warp_row = warp_id / WARPS_N;
      const index_t warp_col = warp_id % WARPS_N;

      // First row visited by this thread and its vector column inside a row.
      const index_t first_row = cta_m * ROWS_PER_CTA + warp_row;
      const index_t vec_col = cta_n * THREADS_PER_ROW + warp_col * THREADS_PER_WARP + lane_id;

      Stats reducer(params, cta_m, cta_n, warp_row, warp_col, lane_id, smem_);

      compute_t *mean_out = static_cast<compute_t *>(params.mu);
      compute_t *rstd_out = static_cast<compute_t *>(params.rs);
      const input_t *rowscale_ptr = static_cast<input_t *>(params.rowscale);
      const index_t *x0_rows = static_cast<index_t *>(params.x0_subset);
      const index_t *z_rows = static_cast<index_t *>(params.z_subset);

      // Per-thread Philox state, seeded the same way as
      // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
      curandStatePhilox4_32_10_t rng;
      if (Is_dropout) {
          auto philox = at::cuda::philox::unpack(params.philox_args);
          const index_t global_tid = blockIdx.x * blockDim.x + threadIdx.x;
          curand_init(std::get<0>(philox), global_tid, std::get<1>(philox), &rng);
      }

      // Number of vector loads of this thread that fall inside the row (used when !Is_even_cols).
      const index_t valid_ldgs =
          ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - vec_col + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;

      // ---- Per-column weights: loaded once, reused for every row. ----
      Wvec gamma_frag[LDGS];
      Wvec beta_frag[LDGS];
      Wvec colscale_frag[LDGS];
      index_t w_idx = vec_col;
      #pragma unroll
      for( int ldg = 0; ldg < LDGS; ldg++ ) {
          if (!Is_even_cols && !(ldg < valid_ldgs)) { continue; }
          gamma_frag[ldg].load_from(params.gamma, w_idx);
          if (params.beta == nullptr) {
              beta_frag[ldg].zero_();
          } else {
              beta_frag[ldg].load_from(params.beta, w_idx);
          }
          if (Has_colscale) { colscale_frag[ldg].load_from(params.colscale, w_idx); }
          w_idx += VEC_COLS_PER_LDG;
      }

      // ---- Row-independent bookkeeping handed to the statistics reduction. ----
      const index_t total_vecs = params.cols / Ktraits::ELTS_PER_LDG;
      const index_t full_ldgs = total_vecs / Ktraits::VEC_COLS_PER_LDG;
      const index_t tail_vecs = total_vecs % Ktraits::VEC_COLS_PER_LDG;
      // Returns the element count computed for the warp with index `w`.
      auto count_valid_elts = [full_ldgs, tail_vecs] (int w) -> int {
          // Work in int: the unsigned subtraction would otherwise wrap around.
          const int tail_left = int(tail_vecs) - int(w * THREADS_PER_WARP);
          const index_t tail_in_warp = std::min(std::max(tail_left, int(0)), int(THREADS_PER_WARP));
          return (full_ldgs * THREADS_PER_WARP + tail_in_warp) * NUM_ELTS;
      };

      const auto row_stride = params.ctas_per_col * ROWS_PER_CTA;
      for( int row = first_row; row < params.rows; row += row_stride ) {
          const compute_t row_scale = Has_subset
              ? params.rowscale_const
              : (params.rowscale == nullptr ? 1.0f : compute_t(rowscale_ptr[row]));

          // 1-based source/destination rows; without a subset every row maps to itself.
          const int x0_row = Has_subset ? x0_rows[row] : row + 1;
          const int z_row = Has_subset ? z_rows[row] : row + 1;
          const bool has_x0 = !Has_subset || x0_row > 0;

          index_t x_idx = row * params.cols / Ktraits::ELTS_PER_LDG + vec_col;
          index_t x0_idx = Has_subset
              ? (has_x0 ? (x0_row - 1) * params.cols / Ktraits::ELTS_PER_LDG + vec_col : 0)
              : x_idx;

          // ---- Step 1: assemble x for this row and keep it in registers. ----
          compute_t vals[LDGS * NUM_ELTS];
          #pragma unroll
          for( int ldg = 0; ldg < LDGS; ldg++ ) {
              if (!Is_even_cols && !(ldg < valid_ldgs)) { continue; }

              Ivec x0_vec;
              Rvec res_vec;
              Rvec x_vec;
              Mvec mask_vec;

              // x0 and dmask share the same index.
              const index_t x0_src = Has_subset ? x0_idx : x_idx;
              if (has_x0) { x0_vec.load_from(params.x0, x0_src); }
              if (with_residual) { res_vec.load_from(params.residual, x_idx); }

              #pragma unroll
              for( int e = 0; e < NUM_ELTS; e++ ) {
                  // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                  // the more efficient curand_uniform4.
                  compute_t x_val;
                  if (!has_x0) {
                      // No x0 contribution for this row: x is the residual (or zero).
                      x_val = with_residual ? compute_t(res_vec.data.elt[e]) : 0.f;
                  } else {
                      mask_t keep = Is_dropout ? (curand_uniform(&rng) <= params.dropout_keep_p) : true;
                      if (Is_dropout) { mask_vec.data.elt[e] = keep; }

                      compute_t x0_val = compute_t(x0_vec.data.elt[e]) * row_scale;
                      if (!keep) {
                          x0_val = 0.0f;
                      } else if (Is_dropout) {
                          x0_val = x0_val * params.dropout_scale;
                      }
                      if (Has_colscale) { x0_val *= compute_t(colscale_frag[ldg].data.elt[e]); }

                      x_val = with_residual ? x0_val + compute_t(res_vec.data.elt[e]) : x0_val;
                  }
                  if (store_x) { x_vec.data.elt[e] = x_val; }
                  vals[ldg * NUM_ELTS + e] = x_val;
              }

              if (store_x) { x_vec.store_to(params.x, x_idx); }
              if (Is_dropout && has_x0) { mask_vec.store_to(params.dmask, x0_src); }
              x_idx += VEC_COLS_PER_LDG;
              x0_idx += VEC_COLS_PER_LDG;
          }

          // ---- Step 2: row statistics. ----
          stats_t st = reducer.template compute<Is_even_cols>(
              vals, params.inverse_cols, count_valid_elts, valid_ldgs * NUM_ELTS
          );
          compute_t mean = layer_norm::Get<0>::of<stats_t, compute_t>(st);
          compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(st);
          compute_t rstd = rsqrtf(m2 * params.inverse_cols + params.epsilon + (params.is_rms_norm ? mean * mean : 0.f));

          // A single thread per row publishes the statistics.
          if( cta_n == 0 && warp_col == 0 && lane_id == 0 ) {
              mean_out[row] = mean;
              rstd_out[row] = rstd;
          }

          // ---- Step 3: normalized output, skipped for rows with no z destination. ----
          if (Has_subset && z_row <= 0) { continue; }

          index_t z_idx = (Has_subset ? (z_row - 1) : row) * params.cols / Ktraits::ELTS_PER_LDG + vec_col;
          #pragma unroll
          for( int ldg = 0; ldg < LDGS; ldg++ ) {
              if (!Is_even_cols && !(ldg < valid_ldgs)) { continue; }

              Ovec z_vec;
              #pragma unroll
              for( int e = 0; e < NUM_ELTS; e++ ) {
                  compute_t y_val = compute_t(rstd * (vals[ldg * NUM_ELTS + e] - (params.is_rms_norm ? 0.f : mean)));
                  compute_t g_val = gamma_frag[ldg].data.elt[e];
                  compute_t b_val = beta_frag[ldg].data.elt[e];
                  z_vec.data.elt[e] = output_t(g_val * y_val + b_val);
              }
              z_vec.store_to(params.z, z_idx);
              z_idx += VEC_COLS_PER_LDG;
          }
      }
  }

  }  // namespace layer_norm
  169. using namespace layer_norm;
  // Host-side entry point for ln_fwd_kernel, specialised on the element types and tile geometry.
  //
  // The runtime flags (dropout enabled, colscale present, subset indices present, cols equal to
  // HIDDEN_SIZE) are turned into compile-time template arguments through nested BOOL_SWITCHes.
  //
  // launch_params    - in/out launch description; params.ctas_per_col, elts_per_thread,
  //                    barrier_size and workspace_bytes are written in configure mode.
  // configure_params - true:  only fill in the sizing fields above, do not launch.
  //                    false: launch the kernel on launch_params.stream using the
  //                           previously configured params.ctas_per_col.
  //
  // All CUDA runtime calls, including the launches themselves, are checked with CHECK_CUDA.
  template<
      typename weight_t,
      typename input_t,
      typename residual_t,
      typename output_t,
      typename compute_t,
      typename index_t,
      int HIDDEN_SIZE,
      int CTAS_PER_ROW,
      int WARPS_M,
      int WARPS_N,
      int BYTES_PER_LDG
  >
  void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){

      using Kernel_traits = Kernel_traits<weight_t,
                                          input_t,
                                          residual_t,
                                          output_t,
                                          compute_t,
                                          index_t,
                                          HIDDEN_SIZE,
                                          CTAS_PER_ROW,
                                          WARPS_M,
                                          WARPS_N,
                                          BYTES_PER_LDG
                                          >;

      bool has_colscale = launch_params.params.colscale != nullptr;
      bool has_subset = launch_params.params.x0_subset != nullptr;
      // The kernel can skip its per-load bounds checks when the row width matches the tile width.
      bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;

      BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
          BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
              BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
                  BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                      auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;

                      if( configure_params ) {
                          // Size the grid from the occupancy of this exact kernel instantiation.
                          int ctas_per_sm;
                          CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                              &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
                          launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;

                          // Number of elements each thread processes over all of its row iterations.
                          const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
                          launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;

                          launch_params.barrier_size = 0;
                          launch_params.workspace_bytes = 0;
                          if(Kernel_traits::CTAS_PER_ROW > 1) {
                              // Barrier and stats workspace are only requested when a row spans several CTAs.
                              launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                              launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                            * Kernel_traits::WARPS_M
                                                            * Kernel_traits::CTAS_PER_ROW
                                                            * sizeof(typename Kernel_traits::Stats::stats_t)
                                                            * 2;
                          }
                          // Configure mode never launches.
                          return;
                      }

                      // Large dynamic shared-memory requests must be opted into explicitly.
                      if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
                          CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
                      }

                      auto stream = launch_params.stream;
                      auto ctas_per_col = launch_params.params.ctas_per_col;

                      if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                          kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
                          // A <<<>>> launch reports configuration errors only through the last-error state.
                          CHECK_CUDA(cudaGetLastError());
                      } else {
                          dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                          dim3 block(Kernel_traits::THREADS_PER_CTA);
                          void *params_ = (void *)&launch_params.params;
                          // The cooperative launch returns its status directly; do not drop it.
                          CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream));
                      }
                  });
              });
          });
      });
  }