ln_api.cpp
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include "ln.h"

/*
Supported Type combinations:

input  residual  compute  weights  output
=========================================
fp32   fp32      fp32     fp32     fp32
fp16   fp32      fp32     fp32     fp16
fp16   fp16      fp32     fp32     fp16
bf16   fp32      fp32     fp32     bf16
bf16   bf16      fp32     fp32     bf16
fp16   fp16      fp32     fp16     fp16
bf16   bf16      fp32     bf16     bf16

Remarks:
Output type = Input type
Compute always in FP32
*/
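// For example (reading the table together with dropout_add_ln_fwd below): an fp16
// input x0 with no residual tensor and residual_in_fp32 == true keeps the saved
// sum tensor `x` in fp32, while the normalized output `z` is returned in fp16 and
// the statistics (mu, rsigma) are always fp32.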
namespace layer_norm {

// Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

uint32_t get_type_id(torch::Dtype dtype){
    if( dtype == torch::kFloat16 ) {
        return TypeId<fp16>::Value;
    } else if( dtype == torch::kBFloat16 ) {
        return TypeId<bf16>::Value;
    } else if( dtype == torch::kFloat32 ) {
        return TypeId<fp32>::Value;
    } else {
        TORCH_CHECK(false, "Type not supported: ", dtype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
    using namespace layer_norm;
    uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
    uint64_t launcher_key = (type_key << 32) | hidden_size;
    return launcher_key;
}

}  // namespace layer_norm
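// Key layout (a sketch derived from get_key above, assuming each TypeId fits in the
// two bits the shifts allow): the low 32 bits hold hidden_size and bits 32..41 hold
// the packed type ids in the order wtype, itype, rtype, otype, ctype. Kernels are
// therefore looked up by the exact (dtype combination, hidden_size) pair.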
////////////////////////////////////////////////////////////////////////////////////////////////////

layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
    auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
    if( iter != layer_norm::FWD_FUNCS.end() ) {
        return iter->second;
    } else {
        TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
    auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
    if( iter != layer_norm::BWD_FUNCS.end() ) {
        return iter->second;
    } else {
        TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
    }
}
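// Illustrative lookup (a sketch; the dtype combination matches a row of the table
// above, but the hidden size of 1024 is hypothetical and must have been registered,
// and launch_params refers to a LaunchParams<FwdParams> filled in as in
// dropout_add_ln_fwd below):
//
//   auto &fwd = get_fwd_launcher(torch::kFloat16, torch::kFloat16, torch::kFloat16,
//                                torch::kFloat16, torch::kFloat32, 1024);
//   fwd(launch_params, true);   // first call: query kernel-specific launch parameters
//   fwd(launch_params, false);  // second call: launch the kernel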
////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,                         // Input: BxSxhidden_size
                                           c10::optional<const at::Tensor> &x1_,         // Residual: BxSxhidden_size
                                           const at::Tensor &gamma,                      // hidden_size
                                           const at::Tensor &beta,                       // hidden_size
                                           c10::optional<const at::Tensor> &rowscale_,   // BxS
                                           const float dropout_p,
                                           const float epsilon,
                                           c10::optional<at::Generator> gen_,
                                           bool residual_in_fp32
) {
    auto itype = x0.scalar_type();
    auto rtype = x1_.has_value()
        ? x1_.value().scalar_type()
        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
    auto wtype = gamma.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    TORCH_CHECK(beta.scalar_type() == wtype);

    TORCH_CHECK(x0.is_cuda());
    TORCH_CHECK(gamma.is_cuda());
    TORCH_CHECK(beta.is_cuda());

    TORCH_CHECK(x0.is_contiguous());
    auto sizes = x0.sizes();
    TORCH_CHECK(sizes.size() == 2);

    const int rows = sizes[0];
    const int cols = sizes[1];
    auto hidden_size = gamma.numel();

    if (x1_.has_value()) {
        auto x1 = x1_.value();
        TORCH_CHECK(x1.is_cuda());
        TORCH_CHECK(x1.is_contiguous());
        TORCH_CHECK(x1.sizes() == sizes);
    }

    if (rowscale_.has_value()) {
        auto rowscale = rowscale_.value();
        TORCH_CHECK(rowscale.is_cuda());
        TORCH_CHECK(rowscale.is_contiguous());
        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
        TORCH_CHECK(rowscale.scalar_type() == itype);
    }

    TORCH_CHECK(gamma.sizes() == beta.sizes());
    TORCH_CHECK(hidden_size == cols);

    TORCH_CHECK(epsilon >= 0.f);

    auto opts = x0.options();

    bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
    at::Tensor x;
    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
    at::Tensor dmask;
    if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); }
    auto z = torch::empty(sizes, opts.dtype(otype));

    auto mu = torch::empty({ rows }, opts.dtype(ctype));
    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));

    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;

    launch_params.props = at::cuda::getCurrentDeviceProperties();
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // Request the kernel launcher.
    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);

    // Query the kernel-specific launch parameters.
    launcher(launch_params, true);

    at::Tensor workspace, barrier;

    // Set the kernel runtime parameters.
    layer_norm::FwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x0 = x0.data_ptr();
    params.x = save_x ? x.data_ptr() : nullptr;
    params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.beta = beta.data_ptr();
    params.z = z.data_ptr();
    params.epsilon = epsilon;
    params.dropout_scale = 1.f / (1.f - dropout_p);

    if (dropout_p > 0.f) {
        // Number of times random numbers will be generated per thread, used to offset
        // the philox counter in the THC random state.
        int64_t counter_offset = launch_params.elts_per_thread;
        // See Note [Acquire lock when using random generators]
        {
            std::lock_guard<std::mutex> lock(gen->mutex_);
            params.philox_args = gen->philox_cuda_state(counter_offset);
        }
    }

    if( launch_params.barrier_size > 0 ) {
        auto options = x0.options();
        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    // Launch the kernel.
    launcher(launch_params, false);

    return { z, x, dmask, mu, rsigma };
}
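// Note on the tuple returned above (describing the code as written): z is the
// normalized output, x is the saved input-plus-residual (left undefined when
// save_x is false), dmask is the dropout mask (left undefined when dropout_p == 0),
// and mu / rsigma are the per-row statistics, always in fp32.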
////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,                         // BxSxhidden_size
                                           const at::Tensor &x,                          // BxSxhidden_size
                                           c10::optional<const at::Tensor> &dmask_,      // BxSxhidden_size
                                           const at::Tensor &mu,                         // BxS, FP32!
                                           const at::Tensor &rsigma,                     // BxS, FP32!
                                           const at::Tensor &gamma,                      // hidden_size
                                           c10::optional<const at::Tensor> &rowscale_,   // BxS
                                           const float dropout_p,
                                           const bool has_residual
) {
    auto itype = dz.scalar_type();
    auto rtype = x.scalar_type();
    auto wtype = gamma.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }

    TORCH_CHECK(dz.dtype() == otype);
    TORCH_CHECK(mu.dtype() == ctype);
    TORCH_CHECK(rsigma.dtype() == ctype);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(dz.is_cuda());
    TORCH_CHECK(mu.is_cuda());
    TORCH_CHECK(rsigma.is_cuda());
    TORCH_CHECK(gamma.is_cuda());

    TORCH_CHECK(x.is_contiguous());
    TORCH_CHECK(dz.is_contiguous());

    auto sizes = x.sizes();
    TORCH_CHECK(sizes.size() == 2);
    TORCH_CHECK(dz.sizes() == sizes);
    auto rows = sizes[0];
    auto cols = sizes[1];

    if (dmask_.has_value()) {
        auto dmask = dmask_.value();
        TORCH_CHECK(dmask.dtype() == mtype);
        TORCH_CHECK(dmask.is_cuda());
        TORCH_CHECK(dmask.is_contiguous());
        TORCH_CHECK(dmask.sizes() == sizes);
    }

    if (rowscale_.has_value()) {
        auto rowscale = rowscale_.value();
        TORCH_CHECK(rowscale.is_cuda());
        TORCH_CHECK(rowscale.is_contiguous());
        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
        TORCH_CHECK(rowscale.scalar_type() == itype);
    }

    auto hidden_size = gamma.numel();

    TORCH_CHECK(mu.numel() == rows);
    TORCH_CHECK(mu.sizes() == rsigma.sizes());

    TORCH_CHECK(gamma.numel() == cols);

    auto opts = x.options();

    auto dx0 = torch::empty_like(x, opts.dtype(itype));
    at::Tensor dx1;
    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
    auto dgamma = torch::empty_like(gamma);
    auto dbeta = torch::empty_like(gamma);

    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    launch_params.props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;

    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);

    launcher(launch_params, true, /*prenorm=*/false);

    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    at::Tensor workspace, barrier;

    layer_norm::BwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.dz = dz.data_ptr();
    params.dx0 = dx0.data_ptr();
    params.dbeta = dbeta.data_ptr();
    params.dgamma = dgamma.data_ptr();
    params.dbeta_part = dbeta_part.data_ptr();
    params.dgamma_part = dgamma_part.data_ptr();
    params.dropout_scale = 1.f / (1.f - dropout_p);

    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    launcher(launch_params, false, /*prenorm=*/false);

    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
}
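// Note on the tuple returned above: dx0 is the gradient w.r.t. the input x0, dx1 the
// gradient w.r.t. the residual (left undefined when has_residual is false), and
// dgamma / dbeta the weight and bias gradients. dgamma_part / dbeta_part are the
// ctas_per_col x hidden_size buffers filled by the kernel, presumably with per-CTA
// partial reductions of the weight gradients.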
////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,                         // BxSxhidden_size
                                                   const at::Tensor &dx,                         // BxSxhidden_size
                                                   const at::Tensor &x,                          // BxSxhidden_size
                                                   c10::optional<const at::Tensor> &dmask_,      // BxSxhidden_size
                                                   const at::Tensor &mu,                         // BxS, FP32!
                                                   const at::Tensor &rsigma,                     // BxS, FP32!
                                                   const at::Tensor &gamma,                      // hidden_size
                                                   c10::optional<const at::Tensor> &rowscale_,   // BxS
                                                   const float dropout_p,
                                                   const bool has_residual
) {
    auto itype = dz.scalar_type();
    auto rtype = x.scalar_type();
    auto wtype = gamma.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }

    TORCH_CHECK(dz.dtype() == otype);
    TORCH_CHECK(dx.dtype() == rtype);
    TORCH_CHECK(mu.dtype() == ctype);
    TORCH_CHECK(rsigma.dtype() == ctype);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(dz.is_cuda());
    TORCH_CHECK(dx.is_cuda());
    TORCH_CHECK(mu.is_cuda());
    TORCH_CHECK(rsigma.is_cuda());
    TORCH_CHECK(gamma.is_cuda());

    TORCH_CHECK(x.is_contiguous());
    TORCH_CHECK(dz.is_contiguous());
    TORCH_CHECK(dx.is_contiguous());

    auto sizes = x.sizes();
    TORCH_CHECK(sizes.size() == 2);
    TORCH_CHECK(dz.sizes() == sizes);
    TORCH_CHECK(dx.sizes() == sizes);
    auto rows = sizes[0];
    auto cols = sizes[1];

    if (dmask_.has_value()) {
        auto dmask = dmask_.value();
        TORCH_CHECK(dmask.dtype() == mtype);
        TORCH_CHECK(dmask.is_cuda());
        TORCH_CHECK(dmask.is_contiguous());
        TORCH_CHECK(dmask.sizes() == sizes);
    }

    if (rowscale_.has_value()) {
        auto rowscale = rowscale_.value();
        TORCH_CHECK(rowscale.is_cuda());
        TORCH_CHECK(rowscale.is_contiguous());
        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
        TORCH_CHECK(rowscale.scalar_type() == itype);
    }

    auto hidden_size = gamma.numel();

    TORCH_CHECK(mu.numel() == rows);
    TORCH_CHECK(mu.sizes() == rsigma.sizes());

    TORCH_CHECK(gamma.numel() == cols);

    auto opts = x.options();

    auto dx0 = torch::empty_like(x, opts.dtype(itype));
    at::Tensor dx1;
    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
    auto dgamma = torch::empty_like(gamma);
    auto dbeta = torch::empty_like(gamma);

    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    launch_params.props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;

    // TODO: how to set template param for launcher
    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);

    launcher(launch_params, true, /*prenorm=*/true);

    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    at::Tensor workspace, barrier;

    layer_norm::BwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.dz = dz.data_ptr();
    params.dx = dx.data_ptr();
    params.dx0 = dx0.data_ptr();
    params.dbeta = dbeta.data_ptr();
    params.dgamma = dgamma.data_ptr();
    params.dbeta_part = dbeta_part.data_ptr();
    params.dgamma_part = dgamma_part.data_ptr();
    params.dropout_scale = 1.f / (1.f - dropout_p);

    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    launcher(launch_params, false, /*prenorm=*/true);

    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
}
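// Difference from dropout_add_ln_bwd above (as the code shows): the prenorm variant
// additionally takes dx, the upstream gradient flowing directly into the saved
// pre-LayerNorm sum x, checks it against the residual dtype and shape, and passes
// it to the kernel via params.dx alongside /*prenorm=*/true.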
////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "CUDA DropoutAddLayerNorm";
    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
    m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
}
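// Python-side note (a sketch, not part of the bindings above): the module name is
// whatever TORCH_EXTENSION_NAME expands to at build time. dropout_add_ln_fwd returns
// (z, x, dmask, mu, rsigma); the two backward entry points expect those saved tensors
// back, along with the incoming gradient dz (and dx for the prenorm variant), and
// return (dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part).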