// ln_bwd_semi_cuda_kernel.cu
  1. #include "ln.h"
  2. #include "ln_utils.cuh"
  3. #include "ln_kernel_traits.h"
  4. #include "ln_bwd_kernels.cuh"
  5. #include "static_switch.h"
  6. using namespace layer_norm;
  // Host-side launcher for the backward kernels, specialised at compile time on the data
  // types, the hidden size and the tiling parameters.
  //
  // launch_params    : carries the kernel parameters (params), the stream, the device
  //                    properties (props) and the barrier_size / workspace_bytes outputs.
  // configure_params : when true, nothing is launched. The occupancy API is queried for
  //                    ln_bwd_kernel and the result is used to fill params.ctas_per_col,
  //                    barrier_size and workspace_bytes, after which the function returns.
  //                    When false, ln_bwd_kernel is launched on launch_params.stream
  //                    (cooperatively if CTAS_PER_ROW > 1), followed by
  //                    ln_bwd_finalize_kernel on the same stream.
  // prenorm          : forwarded to the kernel as a compile-time flag.
  //
  // Three further compile-time flags are derived from params: dropout_keep_p < 1,
  // dx1 != nullptr and rowscale != nullptr. BOOL_SWITCH turns each runtime bool into a
  // template argument of ln_bwd_kernel.
  //
  // Every CUDA runtime call and every kernel launch is checked through CHECK_CUDA.
  template<
      typename weight_t,
      typename input_t,
      typename residual_t,
      typename output_t,
      typename compute_t,
      typename index_t,
      int HIDDEN_SIZE,
      int CTAS_PER_ROW,
      int WARPS_M,
      int WARPS_N,
      int BYTES_PER_LDG_MAIN,
      int BYTES_PER_LDG_FINAL
  >
  void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){

      using Kernel_traits = Kernel_traits<weight_t,
                                          input_t,
                                          residual_t,
                                          output_t,
                                          compute_t,
                                          index_t,
                                          HIDDEN_SIZE,
                                          CTAS_PER_ROW,
                                          WARPS_M,
                                          WARPS_N,
                                          BYTES_PER_LDG_MAIN
                                          >;
      bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
      bool has_residual = launch_params.params.dx1 != nullptr;
      bool has_rowscale = launch_params.params.rowscale != nullptr;
      BOOL_SWITCH(prenorm, PrenormConst, [&] {
          BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
              BOOL_SWITCH(has_residual, HasResidualConst, [&] {
                  BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
                      auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;

                      if( configure_params ) {
                          // Configuration pass: size the grid, barrier and workspace, launch nothing.
                          int ctas_per_sm;
                          // The result was previously stored and ignored; a failed query would leave
                          // ctas_per_sm uninitialised and corrupt every size computed below.
                          CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                              &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                          launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                          launch_params.barrier_size = 0;
                          launch_params.workspace_bytes = 0;
                          if(Kernel_traits::CTAS_PER_ROW > 1) {
                              // Barrier and workspace are only requested when a row spans several CTAs.
                              launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                              launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                            * Kernel_traits::WARPS_M
                                                            * Kernel_traits::CTAS_PER_ROW
                                                            * sizeof(typename Kernel_traits::reduce_t)
                                                            * 2;
                          }
                          return;
                      }

                      // Dynamic shared memory of 48 KB or more requires an explicit opt-in.
                      if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
                          CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                      }
                      auto stream = launch_params.stream;
                      auto ctas_per_col = launch_params.params.ctas_per_col;

                      if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                          kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                          // A <<<>>> launch reports configuration errors only through the last-error state.
                          CHECK_CUDA(cudaGetLastError());
                      } else {
                          dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                          dim3 block(Kernel_traits::THREADS_PER_CTA);
                          // The cooperative launch API takes an array of pointers to the kernel arguments.
                          void *params_ = (void *)&launch_params.params;
                          CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream));
                      }

                      using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                                 weight_t,
                                                                                 input_t,
                                                                                 residual_t,
                                                                                 output_t,
                                                                                 compute_t,
                                                                                 index_t,
                                                                                 32 * 32,  // THREADS_PER_CTA
                                                                                 BYTES_PER_LDG_FINAL>;

                      auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
                      kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                      CHECK_CUDA(cudaGetLastError());
                  });
              });
          });
      });
  }
  // Create the backward launch functions and register them.
  // Argument order of REGISTER_BWD_LAUNCHER:
  //   HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE,
  //   CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
  // Each hidden size is registered for the same ten type combinations, in the same order.

  // HIDDEN_SIZE = 768
  REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

  // HIDDEN_SIZE = 1024
  REGISTER_BWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

  // HIDDEN_SIZE = 1280
  REGISTER_BWD_LAUNCHER(1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
  REGISTER_BWD_LAUNCHER(1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

  // HIDDEN_SIZE = 1536
  REGISTER_BWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

  // HIDDEN_SIZE = 1600
  REGISTER_BWD_LAUNCHER(1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
  REGISTER_BWD_LAUNCHER(1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);

  // HIDDEN_SIZE = 2048
  REGISTER_BWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

  // HIDDEN_SIZE = 2560
  REGISTER_BWD_LAUNCHER(2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
  REGISTER_BWD_LAUNCHER(2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

  // HIDDEN_SIZE = 3072
  REGISTER_BWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

  // HIDDEN_SIZE = 4096
  REGISTER_BWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

  // HIDDEN_SIZE = 5120
  REGISTER_BWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
  REGISTER_BWD_LAUNCHER(5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
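  // TD [2022-04-22]: Most of the registrations below are disabled to speed up compile time.
  // NOTE: the entries from hidden size 2304 onwards are written in an older 10-argument form
  // (four type arguments); they must be updated to the 11-argument signature documented above
  // before being re-enabled.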
  191. // REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
  192. // REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
  193. // REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
  194. // REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
  195. // REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
  196. // REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
  197. // REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
  198. // REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
  199. // REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
  200. // REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
  201. // REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
  202. // REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
  203. // REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
  204. // REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
  205. // REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
  206. // REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
  207. // REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
  208. // REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
  209. // REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
  210. // REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
  211. // REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
  212. // REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
  213. // REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
  214. // REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
  215. // REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
  216. // REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
  217. // REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
  218. // REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
  219. // REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
  220. // REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
  221. // REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
  222. // REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
  223. // REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
  224. // REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
  225. // REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
  226. // REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
  227. // REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
  228. // REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
  229. // REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
  230. // REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
  231. // REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
  232. // REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
  233. // REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
  234. // REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
  235. // REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
  236. // REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
  237. // REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
  238. // REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
  239. // REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
  240. // REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
  241. // REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
  242. // REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
  243. // REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
  244. // REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
  245. // REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
  246. // REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
  247. // REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
  248. // REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
  249. // REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
  250. // REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
  251. // REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
  252. // REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
  253. // REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
  254. // REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
  255. // REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
  256. // REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
  257. // REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
  258. // REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
  259. // REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
  260. // REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
  261. // REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
  262. // REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
  263. // REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
  264. // REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
  265. // REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
  266. // REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
  267. // REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
  268. // REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
  269. // REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
  270. // REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
  271. // REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
  272. // REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
  273. // REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
  274. // REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
  275. // REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
  276. // REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
  277. // REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
  278. // REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
  279. // REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
  280. // REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
  281. // REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
  282. // REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
  283. // REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
  284. // REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
  285. // REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
  286. // REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
  287. // REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);