// ln_bwd_kernels.cuh

#pragma once

#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"

namespace layer_norm {

template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { COLS = Ktraits::COLS };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using input_t = typename Ktraits::input_t;
    using compute_t = typename Ktraits::compute_t;
    using index_t = typename Ktraits::index_t;
    using mask_t = typename Ktraits::mask_t;
    using Ivec = typename Ktraits::Ivec;
    using Rvec = typename Ktraits::Rvec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;
    using Mvec = typename Ktraits::Mvec;
    using Reducer = typename Ktraits::Reducer;
    using reduce_t = typename Reducer::Type;

    extern __shared__ char smem_[];

    const bool has_residual = params.dresidual != nullptr;
    const bool prenorm = params.dx != nullptr;

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / Ktraits::WARPS_N;
    const index_t warp_n = warp % Ktraits::WARPS_N;
    const index_t tid_r = warp_n * THREADS_PER_WARP + lane;

    const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
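    // Added comment: each CTA holds WARPS_M x WARPS_N warps. Warps along M own
    // different rows; the WARPS_N warps of a row (together with CTAS_PER_ROW CTAs
    // when one row spans several CTAs) cover its columns. `r` and `c` are the first
    // row and first vectorized column handled by this thread.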
    const input_t *rowscale = static_cast<input_t *>(params.rowscale);
    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
    const index_t *z_subset = static_cast<index_t *>(params.z_subset);

    Cvec dzy_sum[LDGS];
    Cvec dz_sum[LDGS];
    Cvec dcolscale_sum[LDGS];

    memset(dzy_sum, 0, sizeof(dzy_sum));
    memset(dz_sum, 0, sizeof(dz_sum));
    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }

    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
    char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;

    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);

    Sum<reduce_t> sum;

    const index_t num_valid_ldgs =
        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
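    // Added comment: num_valid_ldgs is the number of this thread's LDGS vectorized
    // loads that start at a column below params.cols; it is only consulted when
    // Is_even_cols is false, i.e. params.cols differs from the compile-time hidden size.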
    Wvec gamma[LDGS];
    Wvec colscale[LDGS];
    index_t idx = c;
    #pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        if (Is_even_cols || (it < num_valid_ldgs)) {
            gamma[it].load_from(params.gamma, idx);
            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
            idx += Ktraits::VEC_COLS_PER_LDG;
        }
    }

    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
    // last blocks with syncthreads!
    // grid stride over rows
    #pragma unroll 1
    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
        const int row_z = !Has_subset ? row + 1 : z_subset[row];
        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
        const bool load_dz = !Has_subset || row_z > 0;
        const bool save_dx0 = !Has_subset || row_x0 > 0;
        Mvec dmask[LDGS];
        Rvec dx[LDGS];
        compute_t dy[LDGS * NUM_ELTS];
        compute_t y[LDGS * NUM_ELTS];
        compute_t mdy_local = 0.f;
        compute_t mdyy_local = 0.f;
        // If dz is not loaded, then dy should be 0 and we don't care about the value of y.
        if (load_dz) {
            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
            index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
            #pragma unroll
            for( int it = 0; it < LDGS; it++ ) {
                if (Is_even_cols || (it < num_valid_ldgs)) {
                    Rvec x;
                    Ovec dz;
                    dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
                    x.load_from(params.x, idx_x);
                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
                    idx_x += Ktraits::VEC_COLS_PER_LDG;
                    idx_z += Ktraits::VEC_COLS_PER_LDG;
                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
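                    // Added comment: per element, y = rs * (x - mu) (mu is dropped for
                    // RMSNorm) and dy = gamma * dz. Accumulate sum(dy) and sum(dy * y)
                    // for the dx formula below, and dz / dz * y for dbeta / dgamma.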
                    #pragma unroll
                    for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                        compute_t x_tmp = x.data.elt[jt];
                        compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
                        compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
                        compute_t dz_tmp = dz.data.elt[jt];

                        mdy_local += dy_tmp;
                        mdyy_local += dy_tmp * y_tmp;

                        dy[it * NUM_ELTS + jt] = dy_tmp;
                        y[it * NUM_ELTS + jt] = y_tmp;

                        dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
                        dz_sum[it].data.elt[jt] += dz_tmp;
                    }
                }
            }
        } else {
            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
            #pragma unroll
            for( int it = 0; it < LDGS; it++ ) {
                if (Is_even_cols || (it < num_valid_ldgs)) {
                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
                    idx_x += Ktraits::VEC_COLS_PER_LDG;
                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
                }
            }
        }
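        // Added comment: reduce sum(dy) and sum(dy * y) across all threads (and CTAs)
        // covering this row, then scale by 1 / cols to get the two mean terms of the
        // backward pass.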
        reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
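        // Added comment: second pass computes dx = rs * (dy - (mdyy * y + mdy)), with the
        // mdy term dropped for RMSNorm. The result is optionally combined with the incoming
        // prenorm gradient, routed to dresidual, and scaled back through rowscale /
        // dropout / colscale to form dx0.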
        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
        index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            if (Is_even_cols || (it < num_valid_ldgs)) {
                Ivec dx0;
                Rvec dresidual;
                Ivec x0;
                if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                #pragma unroll
                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                    compute_t dx_tmp_res;
                    if (load_dz) {
                        compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                        compute_t y_tmp = y[it * NUM_ELTS + jt];
                        compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
                        dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
                    } else {
                        dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
                    }
                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
                    if (save_dx0) {
                        compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                        if (Is_dropout) {
                            dx0_tmp_res *= params.dropout_scale;
                            if (Has_colscale) {
                                dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
                            } else {
                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
                            }
                        } else {
                            if (Has_colscale) {
                                dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
                                dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
                            } else {
                                dx0.data.elt[jt] = dx0_tmp_res;
                            }
                        }
                    }
                }
                if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
                if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
                idx_x += Ktraits::VEC_COLS_PER_LDG;
                idx_x0 += Ktraits::VEC_COLS_PER_LDG;
            }
        }
    }  // end: grid stride loop
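    // Added comment: write per-CTA partial weight gradients. With a single warp along M
    // each thread's accumulators go straight to the *_part buffers; otherwise the
    // ROWS_PER_CTA partial rows are first combined through shared memory.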
    if( WARPS_M == 1 ) {
        idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            if (Is_even_cols || (it < num_valid_ldgs)) {
                dz_sum[it].store_to(params.dbeta_part, idx);
                dzy_sum[it].store_to(params.dgamma_part, idx);
                if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
                idx += Ktraits::VEC_COLS_PER_LDG;
            }
        }
    } else {
        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
        // Finalize reduction of part dgamma and dbeta for this CTA
        // by reducing over the rows held across the WARPS_M warps
        // Assumption: blockSize divides hidden size.
        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");

        idx = warp_m * Ktraits::VEC_COLS + tid_r;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            dz_sum[it].store_to(smem_wgrad, idx);
            idx += THREADS_PER_ROW;
        }
        __syncthreads();
        compute_t cta_dz_sum[NUM_RES];
        memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
            for( int jt = 0; jt < NUM_RES; jt++ ) {
                cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
            }
        }
        __syncthreads();

        idx = warp_m * Ktraits::VEC_COLS + tid_r;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            dzy_sum[it].store_to(smem_wgrad, idx);
            idx += THREADS_PER_ROW;
        }
        __syncthreads();
        compute_t cta_dzy_sum[NUM_RES];
        memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
            for( int jt = 0; jt < NUM_RES; jt++ ) {
                cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
            }
        }

        compute_t cta_dcolscale_sum[NUM_RES];
        if (Has_colscale) {
            __syncthreads();
            idx = warp_m * Ktraits::VEC_COLS + tid_r;
            #pragma unroll
            for( int it = 0; it < LDGS; it++ ) {
                dcolscale_sum[it].store_to(smem_wgrad, idx);
                idx += THREADS_PER_ROW;
            }
            __syncthreads();
            memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
                for( int jt = 0; jt < NUM_RES; jt++ ) {
                    cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
                }
            }
        }

        const index_t num_valid_writes
            = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
        compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
        compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
        for( int jt = 0; jt < NUM_RES; jt++ ) {
            if (Is_even_cols || (jt < num_valid_writes)) {
                *dgamma_part = cta_dzy_sum[jt];
                dgamma_part += Ktraits::THREADS_PER_CTA;

                *dbeta_part = cta_dz_sum[jt];
                dbeta_part += Ktraits::THREADS_PER_CTA;

                if (Has_colscale) {
                    *dcolscale_part = cta_dcolscale_sum[jt];
                    dcolscale_part += Ktraits::THREADS_PER_CTA;
                }
            }
        }
    }
}
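// Added comment: the finalize kernel reduces the per-CTA rows written to
// dgamma_part / dbeta_part (and dcolscale_part) above into the final
// dgamma / dbeta / dcolscale vectors. Each warp accumulates over rows, the partials
// are transposed through shared memory, reduced across the warp, and packed into
// 2-wide stores at the end.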
template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
    using compute_t = typename Kernel_traits::compute_t;
    using weight_t = typename Kernel_traits::weight_t;
    using index_t = typename Kernel_traits::index_t;
    using Reducer = typename Kernel_traits::Reducer;
    using reduce_t = typename Reducer::Type;

    Sum<reduce_t> sum;
    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
    enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };

    __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];

    constexpr uint32_t bidm = 0;
    const uint32_t bidn = blockIdx.x;
    const uint32_t tidx = threadIdx.x;
    const uint32_t warp = tidx / THREADS_PER_WARP;
    const uint32_t lane = tidx % THREADS_PER_WARP;

    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);

    const uint32_t c = bidn * THREADS_PER_WARP + lane;
    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
    for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
        // Each thread sums over NUM_ELT columns.
        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
        memset(&dgamma_local, 0, sizeof(dgamma_local));
        memset(&dbeta_local, 0, sizeof(dbeta_local));
        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
        if (Is_even_cols || col < params.cols) {
            for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
                index_t idx = row * params.cols + col;

                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
                dbeta_part.load_from(params.dbeta_part, idx);
                dgamma_part.load_from(params.dgamma_part, idx);
                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
                #pragma unroll
                for( int it = 0; it < NUM_ELT; it++ ) {
                    dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
                    dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
                }
            }
        }
        void * smem_gamma = smem_;
        void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
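        // Added comment: transpose through shared memory; write_col = lane ^ write_row
        // is an XOR swizzle, presumably to spread the transposed accesses across
        // shared-memory banks.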
        const int write_row = warp;
        const int write_col = lane ^ write_row;
        const int write_idx = write_row * THREADS_PER_WARP + write_col;

        dgamma_local.store_to(smem_gamma, write_idx);
        dbeta_local.store_to(smem_beta, write_idx);
        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }

        __syncthreads();

        // It would be probably safe to reuse the first row of smem_beta and smem_gamma
        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];

        // More than one iter iff ROWS_PER_CTA < 32.
        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
            const int read_row = lane;
            const int read_col = w ^ read_row;
            const int read_idx = read_row * THREADS_PER_WARP + read_col;

            memset(&dbeta_local, 0, sizeof(dbeta_local));
            memset(&dgamma_local, 0, sizeof(dgamma_local));
            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }

            // Load beta and gamma transposed
            if(read_row < Kernel_traits::ROWS_PER_CTA){
                dbeta_local.load_from(smem_beta, read_idx);
                dgamma_local.load_from(smem_gamma, read_idx);
                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
            }

            // Call reducer on the loaded value(s) and convert.
            #pragma unroll
            for( int it = 0; it < NUM_ELT; it++ ) {
                compute_t b_i = dbeta_local.data.elt[it];
                compute_t g_i = dgamma_local.data.elt[it];
                b_i = reducer.allreduce(b_i, sum);
                g_i = reducer.allreduce(g_i, sum);
                dgamma_local.data.elt[it] = g_i;
                dbeta_local.data.elt[it] = b_i;
                if (Has_colscale) {
                    compute_t cs_i = dcolscale_local.data.elt[it];
                    cs_i = reducer.allreduce(cs_i, sum);
                    dcolscale_local.data.elt[it] = cs_i;
                }
            }

            // Leader stores the result at the current column.
            if(lane == 0){
                dgamma_local.store_to(smem_gamma_out, w);
                dbeta_local.store_to(smem_beta_out, w);
                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
            }
        }
        // All writes done.
        __syncthreads();

        // Pack and store: 2-wide stores with half the threads.
        if (Is_even_cols || col_out * 2 < params.cols) {
            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
                using src_t = typename TypeToVec2<compute_t>::Type;
                using dst_t = typename TypeToVec2<weight_t>::Type;
                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;

                dgamma_vec2.load_from(smem_gamma_out, lane);
                dbeta_vec2.load_from(smem_beta_out, lane);
                if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
                #pragma unroll
                for( int it = 0; it < NUM_ELT; it++ ) {
                    dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
                    dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
                    if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
                }
                dgamma_out2.store_to(params.dgamma, col_out);
                dbeta_out2.store_to(params.dbeta, col_out);
                if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
            }
        }
    }
}

}  // namespace layer_norm

using namespace layer_norm;
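// Added comment: the launcher instantiates the main kernel for the compile-time traits,
// switches the runtime flags (dropout, colscale, x0/z subset, even columns) into template
// booleans via BOOL_SWITCH, and either fills in the launch configuration
// (configure_params == true) or launches the main kernel followed by the
// weight-gradient finalize kernel.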
template<
    typename weight_t,
    typename input_t,
    typename residual_t,
    typename output_t,
    typename compute_t,
    typename index_t,
    int HIDDEN_SIZE,
    int CTAS_PER_ROW,
    int WARPS_M,
    int WARPS_N,
    int BYTES_PER_LDG_MAIN,
    int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
    using Kernel_traits = Kernel_traits<weight_t,
                                        input_t,
                                        residual_t,
                                        output_t,
                                        compute_t,
                                        index_t,
                                        HIDDEN_SIZE,
                                        CTAS_PER_ROW,
                                        WARPS_M,
                                        WARPS_N,
                                        BYTES_PER_LDG_MAIN
                                        >;
    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
    bool has_colscale = launch_params.params.colscale != nullptr;
    bool has_subset = launch_params.params.x0_subset != nullptr;
    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                    auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
                    if( configure_params ) {
                        int ctas_per_sm;
                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                        launch_params.barrier_size = 0;
                        launch_params.workspace_bytes = 0;
                        if(Kernel_traits::CTAS_PER_ROW > 1) {
                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                * Kernel_traits::WARPS_M
                                * Kernel_traits::CTAS_PER_ROW
                                * sizeof(typename Kernel_traits::reduce_t)
                                * 2;
                        }
                        return;
                    }
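                    // Added comment: dynamic shared memory above 48 KB requires an explicit
                    // opt-in via cudaFuncSetAttribute before launch.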
                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                    }
                    auto stream = launch_params.stream;
                    auto ctas_per_col = launch_params.params.ctas_per_col;
                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                    } else {
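                        // Added comment: when a row spans several CTAs they synchronize
                        // through the inter-CTA reducer; a cooperative launch guarantees
                        // all CTAs in the grid are co-resident, which that synchronization
                        // presumably relies on.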
                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                        dim3 block(Kernel_traits::THREADS_PER_CTA);
                        void *params_ = (void *)&launch_params.params;
                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
                    }

                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                               weight_t,
                                                                               input_t,
                                                                               residual_t,
                                                                               output_t,
                                                                               compute_t,
                                                                               index_t,
                                                                               HasColscaleConst,
                                                                               32 * 32,  // THREADS_PER_CTA
                                                                               BYTES_PER_LDG_FINAL>;

                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                });
            });
        });
    });
}