ln_parallel_residual_bwd_kernels.cuh 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. #pragma once
  2. #include "ln.h"
  3. #include "ln_utils.cuh"
  4. #include "ln_kernel_traits.h"
  5. #include "static_switch.h"
  6. #include "ln_bwd_kernels.cuh"
  7. namespace layer_norm {
  // Backward kernel for layer norm with a parallel residual (two norm branches that
  // share the same normalized input).
  //
  // Reads the upstream gradients params.dz (and params.dz1 when !Tied_norm), the saved
  // input params.x and per-row statistics params.mu / params.rs, and writes:
  //   - params.dx0, plus params.dx1 / params.dresidual when those pointers are non-null;
  //   - per-CTA partial sums for the weight gradients into params.dgamma_part /
  //     params.dbeta_part (and params.dgamma1_part / params.dbeta1_part when !Tied_norm).
  //
  // Template parameters:
  //   Is_dropout   - apply the saved dropout masks (params.dmask / params.dmask1) and
  //                  params.dropout_scale to dx0 / dx1.
  //   Tied_norm    - when true, only gamma0 / dz0 are used; every gamma1 / dz1 path is skipped.
  //   Is_even_cols - when true, the per-load `it < num_valid_ldgs` bounds check is skipped.
  //
  // Launch layout: 1D grid; blockIdx.x is split into (bidm, bidn) using CTAS_PER_ROW.
  // Dynamic shared memory: the first Ktraits::SMEM_BYTES_WGRAD bytes are used as scratch
  // for the weight-gradient reduction, the remainder is handed to the Reducer.
  template<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>
  __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
  void ln_parallel_residual_bwd_kernel(layer_norm::BwdParams params) {

      enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
      enum { WARPS_M = Ktraits::WARPS_M };
      enum { WARPS_N = Ktraits::WARPS_N };
      enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
      enum { COLS = Ktraits::COLS };
      enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
      enum { LDGS = Ktraits::LDGS };
      enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
      enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
      enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

      using input_t = typename Ktraits::input_t;
      using compute_t = typename Ktraits::compute_t;
      using index_t = typename Ktraits::index_t;
      using mask_t = typename Ktraits::mask_t;
      using Ivec = typename Ktraits::Ivec;
      using Rvec = typename Ktraits::Rvec;
      using Ovec = typename Ktraits::Ovec;
      using Wvec = typename Ktraits::Wvec;
      using Cvec = typename Ktraits::Cvec;
      using Mvec = typename Ktraits::Mvec;
      using Reducer = typename Ktraits::Reducer;
      using reduce_t = typename Reducer::Type;

      extern __shared__ char smem_[];

      // Optional outputs / inputs are signalled by non-null pointers.
      const bool has_residual = params.dresidual != nullptr;
      const bool has_x1 = params.dx1 != nullptr;
      const bool prenorm = params.dx != nullptr;

      const index_t tidx = threadIdx.x;
      const index_t bidn = blockIdx.x % CTAS_PER_ROW;
      const index_t bidm = blockIdx.x / CTAS_PER_ROW;
      const index_t lane = tidx % THREADS_PER_WARP;
      const index_t warp = tidx / THREADS_PER_WARP;
      const index_t warp_m = warp / Ktraits::WARPS_N;
      const index_t warp_n = warp % Ktraits::WARPS_N;
      const index_t tid_r = warp_n * THREADS_PER_WARP + lane;

      // r: first row handled by this thread; c: its first vector column within a row.
      const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
      const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

      static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);

      // Per-thread running sums over all rows this thread visits:
      // sum(dz) and sum(dz * y) for each norm branch.
      Cvec dz0y_sum[LDGS];
      Cvec dz0_sum[LDGS];
      Cvec dz1y_sum[LDGS];
      Cvec dz1_sum[LDGS];

      memset(dz0y_sum, 0, sizeof(dz0y_sum));
      memset(dz0_sum, 0, sizeof(dz0_sum));
      if (!Tied_norm) {
          memset(dz1y_sum, 0, sizeof(dz1y_sum));
          memset(dz1_sum, 0, sizeof(dz1_sum));
      }

      compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
      char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;

      Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);

      Sum<reduce_t> sum;

      // Number of vector loads of this thread that fall inside params.cols
      // (only consulted when !Is_even_cols).
      const index_t num_valid_ldgs =
          ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;

      // The weights do not depend on the row, so load them once up front.
      Wvec gamma0[LDGS];
      Wvec gamma1[LDGS];
      index_t idx = c;
      #pragma unroll
      for( int it = 0; it < LDGS; it++ ) {
          if (Is_even_cols || (it < num_valid_ldgs)) {
              gamma0[it].load_from(params.gamma, idx);
              if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); }
              idx += Ktraits::VEC_COLS_PER_LDG;
          }
      }

      // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
      // last blocks with syncthreads!
      // grid stride over rows
      #pragma unroll 1
      for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
          const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
          const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
          Mvec dmask0[LDGS], dmask1[LDGS];
          Rvec dx[LDGS];
          compute_t dy[LDGS * NUM_ELTS];
          compute_t y[LDGS * NUM_ELTS];
          compute_t mdy_local = 0.f;
          compute_t mdyy_local = 0.f;

          // Vector index of this thread's first element in `row`. The product
          // row * cols is formed in 64-bit: `row` is an int, and for large tensors the
          // element offset can exceed 32 bits even though the vector index
          // (offset / ELTS_PER_LDG) still fits in index_t.
          const index_t row_start =
              static_cast<index_t>(static_cast<int64_t>(row) * params.cols / Ktraits::ELTS_PER_LDG) + c;

          // Pass 1: load inputs, recompute y, form dy, and accumulate the row sums
          // needed for dx as well as the running weight-gradient sums.
          index_t idx = row_start;
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              if (Is_even_cols || (it < num_valid_ldgs)) {
                  Rvec x;
                  Ovec dz0, dz1;
                  dz0.load_from(params.dz, idx);
                  if (!Tied_norm) { dz1.load_from(params.dz1, idx); }
                  if (prenorm) { dx[it].load_from(params.dx, idx); }
                  x.load_from(params.x, idx);
                  if (Is_dropout) {
                      dmask0[it].load_from(params.dmask, idx);
                      if (has_x1) { dmask1[it].load_from(params.dmask1, idx); }
                  }
                  idx += Ktraits::VEC_COLS_PER_LDG;
                  #pragma unroll
                  for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                      compute_t x_tmp = x.data.elt[jt];
                      // Normalized input; the mean is not subtracted for RMS norm.
                      compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
                      // Gradient w.r.t. y: contributions of both branches add up.
                      compute_t dy_tmp = compute_t(gamma0[it].data.elt[jt]) * compute_t(dz0.data.elt[jt]);
                      if (!Tied_norm) {
                          dy_tmp += compute_t(gamma1[it].data.elt[jt]) * compute_t(dz1.data.elt[jt]);
                      }
                      compute_t dz0_tmp = dz0.data.elt[jt];
                      compute_t dz1_tmp;
                      if (!Tied_norm) { dz1_tmp = dz1.data.elt[jt]; }

                      mdy_local += dy_tmp;
                      mdyy_local += dy_tmp * y_tmp;

                      dy[it * NUM_ELTS + jt] = dy_tmp;
                      y[it * NUM_ELTS + jt] = y_tmp;

                      dz0y_sum[it].data.elt[jt] += dz0_tmp * y_tmp;
                      dz0_sum[it].data.elt[jt] += dz0_tmp;
                      if (!Tied_norm) {
                          dz1y_sum[it].data.elt[jt] += dz1_tmp * y_tmp;
                          dz1_sum[it].data.elt[jt] += dz1_tmp;
                      }
                  }
              }
          }

          // Combine the per-thread sums across the row, then scale by params.inverse_cols.
          reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
          mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
          mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;

          // Pass 2: compute and store the input gradients.
          idx = row_start;
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              if (Is_even_cols || (it < num_valid_ldgs)) {
                  Ivec dx0, dx1;
                  Rvec dresidual;
                  #pragma unroll
                  for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                      compute_t dx_tmp_res;
                      compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                      compute_t y_tmp = y[it * NUM_ELTS + jt];
                      // The mdy term is dropped for RMS norm.
                      compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
                      // With prenorm, add the incoming gradient loaded from params.dx.
                      dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
                      if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
                      if (Is_dropout) {
                          dx0.data.elt[jt] = dmask0[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f;
                          if (has_x1) { dx1.data.elt[jt] = dmask1[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; }
                      } else {
                          dx0.data.elt[jt] = dx_tmp_res;
                          if (has_x1) { dx1.data.elt[jt] = dx_tmp_res; }
                      }
                  }
                  if (has_residual) { dresidual.store_to(params.dresidual, idx); }
                  dx0.store_to(params.dx0, idx);
                  if (has_x1) { dx1.store_to(params.dx1, idx); }
                  idx += Ktraits::VEC_COLS_PER_LDG;
              }
          }

      }  // end: grid stride loop

      if( WARPS_M == 1 ) {
          // One row per CTA: each thread writes its own partial sums directly.
          idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              if (Is_even_cols || (it < num_valid_ldgs)) {
                  dz0_sum[it].store_to(params.dbeta_part, idx);
                  dz0y_sum[it].store_to(params.dgamma_part, idx);
                  if (!Tied_norm) {
                      dz1_sum[it].store_to(params.dbeta1_part, idx);
                      dz1y_sum[it].store_to(params.dgamma1_part, idx);
                  }
                  idx += Ktraits::VEC_COLS_PER_LDG;
              }
          }
      } else {
          static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
          // Finalize reduction of part dgamma and dbeta for this CTA
          // by reducing over the rows held across the WARPS_M warps

          // Assumption: blockSize divides hidden size.
          enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
          static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");

          // The same smem_wgrad scratch is reused for each quantity in turn, so every
          // write phase is separated from the neighbouring read phases by a barrier.

          // --- dbeta0: sum of dz0 over the CTA's rows ---
          idx = warp_m * Ktraits::VEC_COLS + tid_r;
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              dz0_sum[it].store_to(smem_wgrad, idx);
              idx += THREADS_PER_ROW;
          }
          __syncthreads();
          compute_t cta_dz0_sum[NUM_RES];
          memset(cta_dz0_sum, 0, sizeof(compute_t) * NUM_RES);
          for( int it = 0; it < ROWS_PER_CTA; it++ ) {
              for( int jt = 0; jt < NUM_RES; jt++ ) {
                  cta_dz0_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
              }
          }
          __syncthreads();

          // --- dgamma0: sum of dz0 * y over the CTA's rows ---
          idx = warp_m * Ktraits::VEC_COLS + tid_r;
          #pragma unroll
          for( int it = 0; it < LDGS; it++ ) {
              dz0y_sum[it].store_to(smem_wgrad, idx);
              idx += THREADS_PER_ROW;
          }
          __syncthreads();
          compute_t cta_dz0y_sum[NUM_RES];
          memset(cta_dz0y_sum, 0, sizeof(compute_t) * NUM_RES);
          for( int it = 0; it < ROWS_PER_CTA; it++ ) {
              for( int jt = 0; jt < NUM_RES; jt++ ) {
                  cta_dz0y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
              }
          }

          compute_t cta_dz1_sum[NUM_RES], cta_dz1y_sum[NUM_RES];
          if (!Tied_norm) {
              // Tied_norm is a template constant, so these barriers are uniform across the CTA.
              __syncthreads();

              // --- dbeta1 ---
              idx = warp_m * Ktraits::VEC_COLS + tid_r;
              #pragma unroll
              for( int it = 0; it < LDGS; it++ ) {
                  dz1_sum[it].store_to(smem_wgrad, idx);
                  idx += THREADS_PER_ROW;
              }
              __syncthreads();
              memset(cta_dz1_sum, 0, sizeof(compute_t) * NUM_RES);
              for( int it = 0; it < ROWS_PER_CTA; it++ ) {
                  for( int jt = 0; jt < NUM_RES; jt++ ) {
                      cta_dz1_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
                  }
              }
              __syncthreads();

              // --- dgamma1 ---
              idx = warp_m * Ktraits::VEC_COLS + tid_r;
              #pragma unroll
              for( int it = 0; it < LDGS; it++ ) {
                  dz1y_sum[it].store_to(smem_wgrad, idx);
                  idx += THREADS_PER_ROW;
              }
              __syncthreads();
              memset(cta_dz1y_sum, 0, sizeof(compute_t) * NUM_RES);
              for( int it = 0; it < ROWS_PER_CTA; it++ ) {
                  for( int jt = 0; jt < NUM_RES; jt++ ) {
                      cta_dz1y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
                  }
              }
          }

          // Number of this thread's strided scalar writes that fall inside params.cols
          // (only consulted when !Is_even_cols).
          const index_t num_valid_writes
              = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;

          // Each CTA (bidm) owns one row of params.cols scalars in every *_part buffer.
          compute_t *dgamma0_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
          compute_t *dbeta0_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
          compute_t *dgamma1_part = !Tied_norm ? static_cast<compute_t *>(params.dgamma1_part) + bidm * params.cols + tidx : nullptr;
          compute_t *dbeta1_part = !Tied_norm ? static_cast<compute_t *>(params.dbeta1_part) + bidm * params.cols + tidx : nullptr;
          for( int jt = 0; jt < NUM_RES; jt++ ) {
              if (Is_even_cols || (jt < num_valid_writes)) {
                  *dgamma0_part = cta_dz0y_sum[jt];
                  dgamma0_part += Ktraits::THREADS_PER_CTA;
                  *dbeta0_part = cta_dz0_sum[jt];
                  dbeta0_part += Ktraits::THREADS_PER_CTA;
                  if (!Tied_norm) {
                      *dgamma1_part = cta_dz1y_sum[jt];
                      dgamma1_part += Ktraits::THREADS_PER_CTA;
                      *dbeta1_part = cta_dz1_sum[jt];
                      dbeta1_part += Ktraits::THREADS_PER_CTA;
                  }
              }
          }

      }
  }
  // Second-stage reduction for the parallel-residual backward pass.
  //
  // Sums the per-CTA partials in params.dgamma_part / dbeta_part / dgamma1_part /
  // dbeta1_part over params.ctas_per_col rows, converts the totals to weight_t and
  // stores them to params.dgamma / dbeta / dgamma1 / dbeta1.
  //
  // Is_even_cols: when true, the column bounds checks against params.cols are skipped.
  // Uses a statically sized shared-memory buffer (no dynamic shared memory is referenced).
  template<typename Kernel_traits, bool Is_even_cols>
  __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
  void ln_parallel_residual_bwd_finalize_kernel(BwdParams params)
  {
      using compute_t = typename Kernel_traits::compute_t;
      using weight_t = typename Kernel_traits::weight_t;
      using index_t = typename Kernel_traits::index_t;
      using Reducer = typename Kernel_traits::Reducer;
      using reduce_t = typename Reducer::Type;

      enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
      enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };

      // Every accumulator in this kernel is a NUM_ELT-wide vector of compute_t.
      using AccVec = Vec<compute_t, NUM_ELT>;

      Sum<reduce_t> sum;

      // Doubled because both the gamma0/beta0 and the gamma1/beta1 pair are staged here.
      __shared__ char smem_[2 * Kernel_traits::SMEM_BYTES_PER_CTA];

      constexpr uint32_t bidm = 0;
      const uint32_t bidn = blockIdx.x;
      const uint32_t tidx = threadIdx.x;
      const uint32_t warp = tidx / THREADS_PER_WARP;
      const uint32_t lane = tidx % THREADS_PER_WARP;

      Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);

      const uint32_t c = bidn * THREADS_PER_WARP + lane;
      const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
      constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;

      // Shared-memory layout: four transpose tiles followed by four output rows.
      // (It would be probably safe to reuse the first row of the beta0/gamma0 tiles
      // for the outputs instead.)
      void * const tile_gamma0 = smem_;
      void * const tile_beta0 = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
      void * const tile_gamma1 = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
      void * const tile_beta1 = &smem_[3 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
      void * const out_gamma0 = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
      void * const out_beta0 = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
      void * const out_gamma1 = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
      void * const out_beta1 = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 3 * Kernel_traits::SMEM_BYTES_OUTPUT];

      // Where this thread deposits its partial sum: row = warp, column XOR-swizzled
      // with the row (presumably to avoid shared-memory bank conflicts on the
      // transposed read below).
      const int write_idx = static_cast<int>(warp) * THREADS_PER_WARP + static_cast<int>(lane ^ warp);

      for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {

          // Stage 1: each thread accumulates its column over a strided subset of the
          // partial rows (rows warp, warp + ROWS_PER_CTA, ...).
          AccVec dgamma0_acc, dbeta0_acc, dgamma1_acc, dbeta1_acc;
          memset(&dgamma0_acc, 0, sizeof(dgamma0_acc));
          memset(&dbeta0_acc, 0, sizeof(dbeta0_acc));
          memset(&dgamma1_acc, 0, sizeof(dgamma1_acc));
          memset(&dbeta1_acc, 0, sizeof(dbeta1_acc));

          const bool col_in_range = Is_even_cols || col < params.cols;
          if (col_in_range) {
              for( uint32_t part_row = warp; part_row < params.ctas_per_col; part_row += Kernel_traits::ROWS_PER_CTA ) {
                  const index_t part_idx = part_row * params.cols + col;

                  AccVec dbeta0_in, dgamma0_in, dbeta1_in, dgamma1_in;
                  dbeta0_in.load_from(params.dbeta_part, part_idx);
                  dgamma0_in.load_from(params.dgamma_part, part_idx);
                  dbeta1_in.load_from(params.dbeta1_part, part_idx);
                  dgamma1_in.load_from(params.dgamma1_part, part_idx);

                  #pragma unroll
                  for( int e = 0; e < NUM_ELT; e++ ) {
                      dgamma0_acc.data.elt[e] += dgamma0_in.data.elt[e];
                      dbeta0_acc.data.elt[e] += dbeta0_in.data.elt[e];
                      dgamma1_acc.data.elt[e] += dgamma1_in.data.elt[e];
                      dbeta1_acc.data.elt[e] += dbeta1_in.data.elt[e];
                  }
              }
          }

          // Stage 2: publish the per-thread sums to the transpose tiles.
          dgamma0_acc.store_to(tile_gamma0, write_idx);
          dbeta0_acc.store_to(tile_beta0, write_idx);
          dgamma1_acc.store_to(tile_gamma1, write_idx);
          dbeta1_acc.store_to(tile_beta1, write_idx);

          __syncthreads();

          // Stage 3: read the tiles transposed so that a single warp holds all row
          // contributions of one column, then reduce across the warp.
          // More than one iter iff ROWS_PER_CTA < 32.
          for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
              const int src_row = lane;
              const int read_idx = src_row * THREADS_PER_WARP + (w ^ src_row);

              // Lanes beyond ROWS_PER_CTA have no tile row and contribute zeros.
              AccVec dbeta0_red, dgamma0_red, dbeta1_red, dgamma1_red;
              memset(&dbeta0_red, 0, sizeof(dbeta0_red));
              memset(&dgamma0_red, 0, sizeof(dgamma0_red));
              memset(&dbeta1_red, 0, sizeof(dbeta1_red));
              memset(&dgamma1_red, 0, sizeof(dgamma1_red));

              if( src_row < Kernel_traits::ROWS_PER_CTA ) {
                  dbeta0_red.load_from(tile_beta0, read_idx);
                  dgamma0_red.load_from(tile_gamma0, read_idx);
                  dbeta1_red.load_from(tile_beta1, read_idx);
                  dgamma1_red.load_from(tile_gamma1, read_idx);
              }

              #pragma unroll
              for( int e = 0; e < NUM_ELT; e++ ) {
                  compute_t b0 = dbeta0_red.data.elt[e];
                  compute_t g0 = dgamma0_red.data.elt[e];
                  compute_t b1 = dbeta1_red.data.elt[e];
                  compute_t g1 = dgamma1_red.data.elt[e];

                  b0 = reducer.allreduce(b0, sum);
                  g0 = reducer.allreduce(g0, sum);
                  b1 = reducer.allreduce(b1, sum);
                  g1 = reducer.allreduce(g1, sum);

                  dgamma0_red.data.elt[e] = g0;
                  dbeta0_red.data.elt[e] = b0;
                  dgamma1_red.data.elt[e] = g1;
                  dbeta1_red.data.elt[e] = b1;
              }

              // The warp leader records the reduced value for column slot w.
              if( lane == 0 ) {
                  dgamma0_red.store_to(out_gamma0, w);
                  dbeta0_red.store_to(out_beta0, w);
                  dgamma1_red.store_to(out_gamma1, w);
                  dbeta1_red.store_to(out_beta1, w);
              }
          }

          // All output slots are written past this barrier.
          __syncthreads();

          // Stage 4: pack and store. Half of the lanes of the last warp each emit a
          // 2-wide converted value.
          const bool out_in_range = Is_even_cols || col_out * 2 < params.cols;
          const bool is_writer = warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2;
          if( out_in_range && is_writer ) {
              using src_t = typename TypeToVec2<compute_t>::Type;
              using dst_t = typename TypeToVec2<weight_t>::Type;

              Vec<src_t, NUM_ELT> dbeta0_pair, dgamma0_pair, dbeta1_pair, dgamma1_pair;
              Vec<dst_t, NUM_ELT> dbeta0_packed, dgamma0_packed, dbeta1_packed, dgamma1_packed;

              dgamma0_pair.load_from(out_gamma0, lane);
              dbeta0_pair.load_from(out_beta0, lane);
              dgamma1_pair.load_from(out_gamma1, lane);
              dbeta1_pair.load_from(out_beta1, lane);

              #pragma unroll
              for( int e = 0; e < NUM_ELT; e++ ) {
                  dgamma0_packed.data.elt[e] = Converter<src_t,dst_t>::convert(dgamma0_pair.data.elt[e]);
                  dbeta0_packed.data.elt[e] = Converter<src_t,dst_t>::convert(dbeta0_pair.data.elt[e]);
                  dgamma1_packed.data.elt[e] = Converter<src_t,dst_t>::convert(dgamma1_pair.data.elt[e]);
                  dbeta1_packed.data.elt[e] = Converter<src_t,dst_t>::convert(dbeta1_pair.data.elt[e]);
              }

              dgamma0_packed.store_to(params.dgamma, col_out);
              dbeta0_packed.store_to(params.dbeta, col_out);
              dgamma1_packed.store_to(params.dgamma1, col_out);
              dbeta1_packed.store_to(params.dbeta1, col_out);
          }
      }
  }
  394. } // namespace layer_norm
  395. using namespace layer_norm;
  // Host-side launcher for the parallel-residual layer-norm backward pass.
  //
  // launch_params    - carries the kernel parameters, stream and device properties;
  //                    in configure mode its ctas_per_col, barrier_size and
  //                    workspace_bytes fields are filled in.
  // configure_params - when true, only sizes the launch (occupancy-derived
  //                    ctas_per_col, barrier and workspace sizes) and returns without
  //                    launching anything. When false, launches the main backward
  //                    kernel followed by the finalize kernel on launch_params.stream.
  //
  // The runtime flags (dropout, tied norm, even columns) are turned into compile-time
  // constants via BOOL_SWITCH so the matching kernel instantiation is selected.
  // All CUDA runtime calls and kernel launches are checked with CHECK_CUDA.
  template<
      typename weight_t,
      typename input_t,
      typename residual_t,
      typename output_t,
      typename compute_t,
      typename index_t,
      int HIDDEN_SIZE,
      int CTAS_PER_ROW,
      int WARPS_M,
      int WARPS_N,
      int BYTES_PER_LDG_MAIN,
      int BYTES_PER_LDG_FINAL
  >
  void launch_parallel_residual_(LaunchParams<BwdParams> &launch_params, const bool configure_params){

      using Kernel_traits = Kernel_traits<weight_t,
                                          input_t,
                                          residual_t,
                                          output_t,
                                          compute_t,
                                          index_t,
                                          HIDDEN_SIZE,
                                          CTAS_PER_ROW,
                                          WARPS_M,
                                          WARPS_N,
                                          BYTES_PER_LDG_MAIN
                                          >;
      bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
      // A missing second weight means both branches share gamma0.
      bool tied_norm = launch_params.params.gamma1 == nullptr;
      // The bounds checks can be compiled out when the runtime width matches the template width.
      bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
      BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
          BOOL_SWITCH(tied_norm, TiedNormConst, [&] {
              BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                  auto kernel = &ln_parallel_residual_bwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;
                  if( configure_params ) {
                      int ctas_per_sm;
                      CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                          &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                      launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                      launch_params.barrier_size = 0;
                      launch_params.workspace_bytes = 0;
                      if(Kernel_traits::CTAS_PER_ROW > 1) {
                          // Multi-CTA rows need barriers and a workspace for the cross-CTA reduction.
                          launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                          launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                        * Kernel_traits::WARPS_M
                                                        * Kernel_traits::CTAS_PER_ROW
                                                        * sizeof(typename Kernel_traits::reduce_t)
                                                        * 2;
                      }
                      return;
                  }

                  // Opt in to large dynamic shared memory when the kernel needs it.
                  if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
                      CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                  }
                  auto stream = launch_params.stream;
                  auto ctas_per_col = launch_params.params.ctas_per_col;

                  if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                      kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                      // A <<<>>> launch reports configuration errors only via the last-error state.
                      CHECK_CUDA(cudaGetLastError());
                  } else {
                      // Several CTAs cooperate on one row, so a cooperative launch is used.
                      dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                      dim3 block(Kernel_traits::THREADS_PER_CTA);
                      void *params_ = (void *)&launch_params.params;
                      // Previously the return code was dropped; a failed cooperative launch
                      // would have gone unnoticed.
                      CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream));
                  }

                  using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                             weight_t,
                                                                             input_t,
                                                                             residual_t,
                                                                             output_t,
                                                                             compute_t,
                                                                             index_t,
                                                                             /*HasColscaleConst=*/false,
                                                                             32 * 32,  // THREADS_PER_CTA
                                                                             BYTES_PER_LDG_FINAL>;

                  // With tied norms there is a single gamma/beta pair, so the regular
                  // finalize kernel is used instead of the two-pair variant.
                  auto kernel_f = !TiedNormConst
                      ? &layer_norm::ln_parallel_residual_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>
                      : &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, /*HasColscaleConst=*/false, IsEvenColsConst>;
                  kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                  CHECK_CUDA(cudaGetLastError());
              });
          });
      });
  }