ln_api.cpp 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846
  1. #include <torch/extension.h>
  2. #include "ATen/cuda/CUDAContext.h"
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include "ln.h"
  /*

  Supported type combinations:

  input   residual   compute   weights   output
  ===============================================
  fp32    fp32       fp32      fp32      fp32
  fp16    fp32       fp32      fp32      fp16
  fp16    fp16       fp32      fp32      fp16
  bf16    fp32       fp32      fp32      bf16
  bf16    bf16       fp32      fp32      bf16
  fp16    fp16       fp32      fp16      fp16
  bf16    bf16       fp32      bf16      bf16

  Remarks:
  - The output type always equals the input type.
  - Computation is always performed in FP32.

  */
  20. namespace layer_norm {
  21. // Create registries and provide runtime versions of config hash functions.
  22. FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
  23. BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
  24. ////////////////////////////////////////////////////////////////////////////////////////////////////
  25. uint32_t get_type_id(torch::Dtype dtype){
  26. if( dtype == torch::kFloat16 ) {
  27. return TypeId<fp16>::Value;
  28. } else if( dtype == torch::kBFloat16 ) {
  29. return TypeId<bf16>::Value;
  30. } else if( dtype == torch::kFloat32 ) {
  31. return TypeId<fp32>::Value;
  32. } else {
  33. TORCH_CHECK(false, "Type not supported: ", dtype);
  34. }
  35. }
  36. ////////////////////////////////////////////////////////////////////////////////////////////////////
  37. uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
  38. using namespace layer_norm;
  39. uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
  40. uint64_t launcher_key = (type_key << 32) | hidden_size;
  41. return launcher_key;
  42. }
  43. } // namespace layer_norm
  44. ////////////////////////////////////////////////////////////////////////////////////////////////////
  45. layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
  46. auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
  47. if( iter != layer_norm::FWD_FUNCS.end() ) {
  48. return iter->second;
  49. } else {
  50. TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
  51. }
  52. }
  53. ////////////////////////////////////////////////////////////////////////////////////////////////////
  54. layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
  55. auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
  56. if( iter != layer_norm::BWD_FUNCS.end() ) {
  57. return iter->second;
  58. } else {
  59. TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
  60. }
  61. }
  62. ////////////////////////////////////////////////////////////////////////////////////////////////////
  63. layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
  64. auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
  65. if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {
  66. return iter->second;
  67. } else {
  68. TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
  69. }
  70. }
  71. ////////////////////////////////////////////////////////////////////////////////////////////////////
  72. layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
  73. auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
  74. if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {
  75. return iter->second;
  76. } else {
  77. TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
  78. }
  79. }
  80. ////////////////////////////////////////////////////////////////////////////////////////////////////
  81. std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
  82. std::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
  83. const at::Tensor &gamma, // hidden_size
  84. std::optional<const at::Tensor> &beta_, // hidden_size
  85. std::optional<const at::Tensor> &rowscale_, // BxS
  86. std::optional<const at::Tensor> &colscale_, // hidden_size
  87. std::optional<const at::Tensor> &x0_subset_, // BxS
  88. std::optional<const at::Tensor> &z_subset_, // BxS
  89. const float dropout_p,
  90. const float epsilon,
  91. const float rowscale_const,
  92. const int64_t z_numrows,
  93. std::optional<at::Generator> gen_,
  94. bool residual_in_fp32=false,
  95. bool is_rms_norm=false
  96. ) {
  97. auto itype = x0.scalar_type();
  98. auto rtype = residual_.has_value()
  99. ? residual_.value().scalar_type()
  100. : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
  101. auto wtype = gamma.scalar_type();
  102. auto otype = itype;
  103. auto ctype = torch::kFloat32;
  104. auto mtype = torch::kUInt8;
  105. TORCH_CHECK(x0.is_cuda());
  106. TORCH_CHECK(gamma.is_cuda());
  107. TORCH_CHECK(x0.is_contiguous());
  108. // c10::IntArrayRef does not own the storage, so we need to construct a vector.
  109. // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
  110. // blah is then deallocated.
  111. std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
  112. auto sizes = c10::IntArrayRef(sizes_vec);
  113. TORCH_CHECK(x0.dim() == 2);
  114. TORCH_CHECK(sizes.size() == 2);
  115. const int rows = sizes[0];
  116. const int cols = sizes[1];
  117. auto hidden_size = gamma.numel();
  118. TORCH_CHECK(hidden_size == cols);
  119. if (beta_.has_value()) {
  120. auto beta = beta_.value();
  121. TORCH_CHECK(beta.dtype() == wtype);
  122. TORCH_CHECK(beta.is_cuda());
  123. TORCH_CHECK(beta.is_contiguous());
  124. TORCH_CHECK(beta.sizes() == gamma.sizes());
  125. }
  126. if (residual_.has_value()) {
  127. auto residual = residual_.value();
  128. TORCH_CHECK(residual.is_cuda());
  129. TORCH_CHECK(residual.is_contiguous());
  130. TORCH_CHECK(residual.sizes() == sizes);
  131. }
  132. if (rowscale_.has_value()) {
  133. auto rowscale = rowscale_.value();
  134. TORCH_CHECK(rowscale.is_cuda());
  135. TORCH_CHECK(rowscale.is_contiguous());
  136. TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
  137. TORCH_CHECK(rowscale.dtype() == itype);
  138. }
  139. if (colscale_.has_value()) {
  140. auto colscale = colscale_.value();
  141. TORCH_CHECK(colscale.is_cuda());
  142. TORCH_CHECK(colscale.is_contiguous());
  143. TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
  144. TORCH_CHECK(colscale.dtype() == wtype);
  145. }
  146. if (x0_subset_.has_value()) {
  147. auto x0_subset = x0_subset_.value();
  148. TORCH_CHECK(x0_subset.is_cuda());
  149. TORCH_CHECK(x0_subset.is_contiguous());
  150. TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
  151. TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
  152. TORCH_CHECK(z_subset_.has_value());
  153. auto z_subset = z_subset_.value();
  154. TORCH_CHECK(z_subset.is_cuda());
  155. TORCH_CHECK(z_subset.is_contiguous());
  156. TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
  157. TORCH_CHECK(z_subset.dtype() == torch::kInt32);
  158. }
  159. TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
  160. TORCH_CHECK(epsilon >= 0.f);
  161. // Otherwise the kernel will be launched from cuda:0 device
  162. at::cuda::CUDAGuard device_guard{x0.device()};
  163. auto opts = x0.options();
  164. bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
  165. at::Tensor x;
  166. if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
  167. at::Tensor dmask;
  168. if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
  169. auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
  170. auto mu = torch::empty({ rows }, opts.dtype(ctype));
  171. auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
  172. layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
  173. launch_params.props = at::cuda::getCurrentDeviceProperties();
  174. launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
  175. TORCH_CHECK(dropout_p < 1.f);
  176. launch_params.params.dropout_keep_p = 1.f - dropout_p;
  177. launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
  178. launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
  179. launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
  180. launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
  181. launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
  182. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  183. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  184. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  185. const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
  186. // Request the kernel launcher.
  187. auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
  188. // Set the kernel runtime parameters.
  189. layer_norm::FwdParams &params = launch_params.params;
  190. params.rows = rows;
  191. params.cols = cols;
  192. params.x0 = x0.data_ptr();
  193. params.x = save_x ? x.data_ptr() : nullptr;
  194. params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
  195. params.mu = mu.data_ptr();
  196. params.rs = rsigma.data_ptr();
  197. params.gamma = gamma.data_ptr();
  198. params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
  199. params.z = z.data_ptr();
  200. params.epsilon = epsilon;
  201. params.dropout_scale = 1.f / (1.f - dropout_p);
  202. params.inverse_cols = 1.f / float(params.cols);
  203. params.rowscale_const = rowscale_const;
  204. params.is_rms_norm = is_rms_norm;
  205. // Query the kernel-specific launch parameters.
  206. launcher(launch_params, true);
  207. at::Tensor workspace, barrier;
  208. if (dropout_p > 0.f) {
  209. // number of times random will be generated per thread, to offset philox counter in thc random
  210. // state
  211. int64_t counter_offset = launch_params.elts_per_thread;
  212. // See Note [Acquire lock when using random generators]
  213. {
  214. std::lock_guard<std::mutex> lock(gen->mutex_);
  215. params.philox_args = gen->philox_cuda_state(counter_offset);
  216. }
  217. }
  218. if( launch_params.barrier_size > 0 ) {
  219. auto options = x0.options();
  220. barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
  221. workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
  222. params.workspace = workspace.data_ptr();
  223. params.barrier = barrier.data_ptr<int>();
  224. }
  225. // Launch the kernel.
  226. launcher(launch_params, false);
  227. return { z, x, dmask, mu, rsigma };
  228. }
  229. ////////////////////////////////////////////////////////////////////////////////////////////////////
  230. std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
  231. std::optional<const at::Tensor> &dx_, // BxSxhidden_size
  232. const at::Tensor &x, // BxSxhidden_size
  233. std::optional<const at::Tensor> &x0_, // BxSxhidden_size
  234. std::optional<const at::Tensor> &dmask_, // BxSxhidden_size
  235. const at::Tensor &mu, // BxS, FP32!
  236. const at::Tensor &rsigma, // BxS, FP32!
  237. const at::Tensor &gamma, // hidden_size
  238. std::optional<const at::Tensor> &rowscale_, // BxS
  239. std::optional<const at::Tensor> &colscale_, // hidden_size
  240. std::optional<const at::Tensor> &x0_subset_, // BxS
  241. std::optional<const at::Tensor> &z_subset_, // BxS
  242. const float dropout_p,
  243. const float rowscale_const,
  244. const int64_t x0_numrows,
  245. const bool has_residual,
  246. bool is_rms_norm=false
  247. ) {
  248. auto itype = dz.scalar_type();
  249. auto rtype = x.scalar_type();
  250. auto wtype = gamma.scalar_type();
  251. auto otype = itype;
  252. auto ctype = torch::kFloat32;
  253. auto mtype = torch::kUInt8;
  254. if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
  255. TORCH_CHECK(dz.dtype() == otype);
  256. TORCH_CHECK(mu.dtype() == ctype);
  257. TORCH_CHECK(rsigma.dtype() == ctype);
  258. TORCH_CHECK(x.is_cuda());
  259. TORCH_CHECK(dz.is_cuda());
  260. TORCH_CHECK(mu.is_cuda());
  261. TORCH_CHECK(rsigma.is_cuda());
  262. TORCH_CHECK(gamma.is_cuda());
  263. TORCH_CHECK(x.is_contiguous());
  264. TORCH_CHECK(dz.is_contiguous());
  265. auto sizes = x.sizes();
  266. TORCH_CHECK(sizes.size() == 2);
  267. auto rows = sizes[0];
  268. auto cols = sizes[1];
  269. TORCH_CHECK(dz.dim() == 2);
  270. TORCH_CHECK(dz.size(1) == cols);
  271. auto hidden_size = gamma.numel();
  272. TORCH_CHECK(hidden_size == cols);
  273. // c10::IntArrayRef does not own the storage, so we need to construct a vector.
  274. // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
  275. // blah is then deallocated.
  276. std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
  277. auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
  278. if (dx_.has_value()) {
  279. auto dx = dx_.value();
  280. TORCH_CHECK(dx.dtype() == rtype);
  281. TORCH_CHECK(dx.is_cuda());
  282. TORCH_CHECK(dx.is_contiguous());
  283. TORCH_CHECK(dx.sizes() == sizes);
  284. }
  285. if (dmask_.has_value()) {
  286. auto dmask = dmask_.value();
  287. TORCH_CHECK(dmask.dtype() == mtype);
  288. TORCH_CHECK(dmask.is_cuda());
  289. TORCH_CHECK(dmask.is_contiguous());
  290. TORCH_CHECK(dmask.sizes() == x0_sizes);
  291. }
  292. if (rowscale_.has_value()) {
  293. auto rowscale = rowscale_.value();
  294. TORCH_CHECK(rowscale.is_cuda());
  295. TORCH_CHECK(rowscale.is_contiguous());
  296. TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
  297. TORCH_CHECK(rowscale.dtype() == itype);
  298. }
  299. if (colscale_.has_value()) {
  300. auto colscale = colscale_.value();
  301. TORCH_CHECK(colscale.is_cuda());
  302. TORCH_CHECK(colscale.is_contiguous());
  303. TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
  304. TORCH_CHECK(colscale.dtype() == wtype);
  305. TORCH_CHECK(x0_.has_value());
  306. auto x0 = x0_.value();
  307. TORCH_CHECK(x0.is_cuda());
  308. TORCH_CHECK(x0.is_contiguous());
  309. TORCH_CHECK(x0.sizes() == x0_sizes);
  310. TORCH_CHECK(x0.dtype() == itype);
  311. }
  312. if (x0_subset_.has_value()) {
  313. auto x0_subset = x0_subset_.value();
  314. TORCH_CHECK(x0_subset.is_cuda());
  315. TORCH_CHECK(x0_subset.is_contiguous());
  316. TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
  317. TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
  318. TORCH_CHECK(z_subset_.has_value());
  319. auto z_subset = z_subset_.value();
  320. TORCH_CHECK(z_subset.is_cuda());
  321. TORCH_CHECK(z_subset.is_contiguous());
  322. TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
  323. TORCH_CHECK(z_subset.dtype() == torch::kInt32);
  324. }
  325. TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
  326. TORCH_CHECK(mu.numel() == rows);
  327. TORCH_CHECK(mu.sizes() == rsigma.sizes());
  328. TORCH_CHECK(gamma.numel() == cols);
  329. // Otherwise the kernel will be launched from cuda:0 device
  330. at::cuda::CUDAGuard device_guard{dz.device()};
  331. auto opts = x.options();
  332. auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
  333. at::Tensor dresidual;
  334. if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
  335. auto dgamma = torch::empty_like(gamma);
  336. auto dbeta = torch::empty_like(gamma);
  337. at::Tensor dcolscale;
  338. if (colscale_.has_value()) {
  339. dcolscale = torch::empty_like(colscale_.value());
  340. }
  341. layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
  342. launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
  343. launch_params.props = at::cuda::getCurrentDeviceProperties();
  344. TORCH_CHECK(dropout_p < 1.f);
  345. launch_params.params.dropout_keep_p = 1.f - dropout_p;
  346. launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
  347. launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
  348. launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
  349. launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
  350. launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
  351. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  352. const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
  353. auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
  354. launcher(launch_params, true);
  355. auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
  356. auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
  357. at::Tensor dcolscale_part;
  358. if (colscale_.has_value()) {
  359. dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
  360. }
  361. at::Tensor workspace, barrier;
  362. layer_norm::BwdParams &params = launch_params.params;
  363. params.rows = rows;
  364. params.cols = cols;
  365. params.x = x.data_ptr();
  366. params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
  367. params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
  368. params.mu = mu.data_ptr();
  369. params.rs = rsigma.data_ptr();
  370. params.gamma = gamma.data_ptr();
  371. params.dz = dz.data_ptr();
  372. params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
  373. params.dx0 = dx0.data_ptr();
  374. params.dbeta = dbeta.data_ptr();
  375. params.dgamma = dgamma.data_ptr();
  376. params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
  377. params.dbeta_part = dbeta_part.data_ptr();
  378. params.dgamma_part = dgamma_part.data_ptr();
  379. params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
  380. params.dropout_scale = 1.f / (1.f - dropout_p);
  381. params.inverse_cols = 1.f / float(params.cols);
  382. params.rowscale_const = rowscale_const;
  383. params.is_rms_norm = is_rms_norm;
  384. if( launch_params.barrier_size > 0 ) {
  385. // TODO Any way to avoid this?
  386. barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
  387. workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
  388. params.workspace = workspace.data_ptr();
  389. params.barrier = barrier.data_ptr<int>();
  390. }
  391. launcher(launch_params, false);
  392. std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
  393. if (colscale_.has_value()) {
  394. result.push_back(dcolscale);
  395. result.push_back(dcolscale_part);
  396. }
  397. return result;
  398. }
  399. ////////////////////////////////////////////////////////////////////////////////////////////////////
  400. std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
  401. const at::Tensor &x0, // Input: BxSxhidden_size
  402. std::optional<const at::Tensor> &x1_, // Input: BxSxhidden_size
  403. std::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
  404. const at::Tensor &gamma0, // hidden_size
  405. std::optional<const at::Tensor> &beta0_, // hidden_size
  406. std::optional<const at::Tensor> &gamma1_, // hidden_size
  407. std::optional<const at::Tensor> &beta1_, // hidden_size
  408. const float dropout_p,
  409. const float epsilon,
  410. std::optional<at::Generator> gen_,
  411. bool residual_in_fp32=false,
  412. bool is_rms_norm=false
  413. ) {
  414. auto itype = x0.scalar_type();
  415. auto rtype = residual_.has_value()
  416. ? residual_.value().scalar_type()
  417. : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
  418. auto wtype = gamma0.scalar_type();
  419. auto otype = itype;
  420. auto ctype = torch::kFloat32;
  421. auto mtype = torch::kUInt8;
  422. TORCH_CHECK(x0.is_cuda());
  423. TORCH_CHECK(gamma0.is_cuda());
  424. TORCH_CHECK(x0.is_contiguous());
  425. const auto sizes = x0.sizes();
  426. TORCH_CHECK(x0.dim() == 2);
  427. const int rows = sizes[0];
  428. const int cols = sizes[1];
  429. auto hidden_size = gamma0.numel();
  430. TORCH_CHECK(hidden_size == cols);
  431. if (x1_.has_value()) {
  432. auto x1 = x1_.value();
  433. TORCH_CHECK(x1.is_cuda());
  434. TORCH_CHECK(x1.is_contiguous());
  435. TORCH_CHECK(x1.sizes() == sizes);
  436. }
  437. if (residual_.has_value()) {
  438. auto residual = residual_.value();
  439. TORCH_CHECK(residual.is_cuda());
  440. TORCH_CHECK(residual.is_contiguous());
  441. TORCH_CHECK(residual.sizes() == sizes);
  442. }
  443. if (beta0_.has_value()) {
  444. auto beta0 = beta0_.value();
  445. TORCH_CHECK(beta0.dtype() == wtype);
  446. TORCH_CHECK(beta0.is_cuda());
  447. TORCH_CHECK(beta0.is_contiguous());
  448. TORCH_CHECK(beta0.sizes() == gamma0.sizes());
  449. }
  450. if (gamma1_.has_value()) {
  451. auto gamma1 = gamma1_.value();
  452. TORCH_CHECK(gamma1.dtype() == wtype);
  453. TORCH_CHECK(gamma1.is_cuda());
  454. TORCH_CHECK(gamma1.is_contiguous());
  455. TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
  456. }
  457. if (beta1_.has_value()) {
  458. auto beta1 = beta1_.value();
  459. TORCH_CHECK(beta1.dtype() == wtype);
  460. TORCH_CHECK(beta1.is_cuda());
  461. TORCH_CHECK(beta1.is_contiguous());
  462. TORCH_CHECK(beta1.sizes() == gamma0.sizes());
  463. }
  464. TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
  465. TORCH_CHECK(epsilon >= 0.f);
  466. // Otherwise the kernel will be launched from cuda:0 device
  467. at::cuda::CUDAGuard device_guard{x0.device()};
  468. auto opts = x0.options();
  469. bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
  470. at::Tensor x;
  471. if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
  472. at::Tensor dmask0, dmask1;
  473. if (dropout_p > 0.f) {
  474. dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
  475. if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
  476. };
  477. auto z0 = torch::empty(sizes, opts.dtype(otype));
  478. at::Tensor z1;
  479. if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }
  480. auto mu = torch::empty({ rows }, opts.dtype(ctype));
  481. auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
  482. layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
  483. launch_params.props = at::cuda::getCurrentDeviceProperties();
  484. launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
  485. TORCH_CHECK(dropout_p < 1.f);
  486. launch_params.params.dropout_keep_p = 1.f - dropout_p;
  487. launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
  488. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  489. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  490. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  491. const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
  492. // Request the kernel launcher.
  493. auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
  494. // Set the kernel runtime parameters.
  495. layer_norm::FwdParams &params = launch_params.params;
  496. params.rows = rows;
  497. params.cols = cols;
  498. params.x0 = x0.data_ptr();
  499. params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
  500. params.x = save_x ? x.data_ptr() : nullptr;
  501. params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
  502. params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
  503. params.mu = mu.data_ptr();
  504. params.rs = rsigma.data_ptr();
  505. params.gamma = gamma0.data_ptr();
  506. params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
  507. params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
  508. params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
  509. params.z = z0.data_ptr();
  510. params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
  511. params.epsilon = epsilon;
  512. params.dropout_scale = 1.f / (1.f - dropout_p);
  513. params.inverse_cols = 1.f / float(params.cols);
  514. params.is_rms_norm = is_rms_norm;
  515. // Query the kernel-specific launch parameters.
  516. launcher(launch_params, true);
  517. at::Tensor workspace, barrier;
  518. if (dropout_p > 0.f) {
  519. // number of times random will be generated per thread, to offset philox counter in thc random
  520. // state
  521. int64_t counter_offset = 2 * launch_params.elts_per_thread;
  522. // See Note [Acquire lock when using random generators]
  523. {
  524. std::lock_guard<std::mutex> lock(gen->mutex_);
  525. params.philox_args = gen->philox_cuda_state(counter_offset);
  526. }
  527. }
  528. if( launch_params.barrier_size > 0 ) {
  529. auto options = x0.options();
  530. barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
  531. workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
  532. params.workspace = workspace.data_ptr();
  533. params.barrier = barrier.data_ptr<int>();
  534. }
  535. // Launch the kernel.
  536. launcher(launch_params, false);
  537. return { z0, z1, x, dmask0, dmask1, mu, rsigma };
  538. }
  539. ////////////////////////////////////////////////////////////////////////////////////////////////////
  540. std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
  541. const at::Tensor &dz0, // BxSxhidden_size
  542. std::optional<const at::Tensor> &dz1_, // BxSxhidden_size
  543. std::optional<const at::Tensor> &dx_, // BxSxhidden_size
  544. const at::Tensor &x, // BxSxhidden_size
  545. std::optional<const at::Tensor> &dmask0_, // BxSxhidden_size
  546. std::optional<const at::Tensor> &dmask1_, // BxSxhidden_size
  547. const at::Tensor &mu, // BxS, FP32!
  548. const at::Tensor &rsigma, // BxS, FP32!
  549. const at::Tensor &gamma0, // hidden_size
  550. std::optional<const at::Tensor> &gamma1_, // hidden_size
  551. const float dropout_p,
  552. const bool has_x1,
  553. const bool has_residual,
  554. bool is_rms_norm=false
  555. ) {
  556. auto itype = dz0.scalar_type();
  557. auto rtype = x.scalar_type();
  558. auto wtype = gamma0.scalar_type();
  559. auto otype = itype;
  560. auto ctype = torch::kFloat32;
  561. auto mtype = torch::kUInt8;
  562. if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }
  563. TORCH_CHECK(dz0.dtype() == otype);
  564. TORCH_CHECK(dz0.dtype() == otype);
  565. TORCH_CHECK(mu.dtype() == ctype);
  566. TORCH_CHECK(rsigma.dtype() == ctype);
  567. TORCH_CHECK(x.is_cuda());
  568. TORCH_CHECK(dz0.is_cuda());
  569. TORCH_CHECK(mu.is_cuda());
  570. TORCH_CHECK(rsigma.is_cuda());
  571. TORCH_CHECK(gamma0.is_cuda());
  572. TORCH_CHECK(x.is_contiguous());
  573. TORCH_CHECK(dz0.is_contiguous());
  574. auto sizes = x.sizes();
  575. TORCH_CHECK(sizes.size() == 2);
  576. auto rows = sizes[0];
  577. auto cols = sizes[1];
  578. TORCH_CHECK(dz0.dim() == 2);
  579. TORCH_CHECK(dz0.size(1) == cols);
  580. auto hidden_size = gamma0.numel();
  581. TORCH_CHECK(hidden_size == cols);
  582. if (dz1_.has_value()) {
  583. auto dz1 = dz1_.value();
  584. TORCH_CHECK(dz1.dtype() == otype);
  585. TORCH_CHECK(dz1.is_cuda());
  586. TORCH_CHECK(dz1.is_contiguous());
  587. TORCH_CHECK(dz1.sizes() == sizes);
  588. TORCH_CHECK(gamma1_.has_value());
  589. auto gamma1 = gamma1_.value();
  590. TORCH_CHECK(gamma1.dtype() == wtype);
  591. TORCH_CHECK(gamma1.is_cuda());
  592. TORCH_CHECK(gamma1.is_contiguous());
  593. TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
  594. }
  595. if (dx_.has_value()) {
  596. auto dx = dx_.value();
  597. TORCH_CHECK(dx.dtype() == rtype);
  598. TORCH_CHECK(dx.is_cuda());
  599. TORCH_CHECK(dx.is_contiguous());
  600. TORCH_CHECK(dx.sizes() == sizes);
  601. }
  602. if (dmask0_.has_value()) {
  603. auto dmask0 = dmask0_.value();
  604. TORCH_CHECK(dmask0.dtype() == mtype);
  605. TORCH_CHECK(dmask0.is_cuda());
  606. TORCH_CHECK(dmask0.is_contiguous());
  607. TORCH_CHECK(dmask0.sizes() == sizes);
  608. if (has_x1) {
  609. TORCH_CHECK(dmask1_.has_value());
  610. auto dmask1 = dmask1_.value();
  611. TORCH_CHECK(dmask1.dtype() == mtype);
  612. TORCH_CHECK(dmask1.is_cuda());
  613. TORCH_CHECK(dmask1.is_contiguous());
  614. TORCH_CHECK(dmask1.sizes() == sizes);
  615. }
  616. }
  617. TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
  618. TORCH_CHECK(mu.numel() == rows);
  619. TORCH_CHECK(mu.sizes() == rsigma.sizes());
  620. // Otherwise the kernel will be launched from cuda:0 device
  621. at::cuda::CUDAGuard device_guard{dz0.device()};
  622. auto opts = x.options();
  623. auto dx0 = torch::empty(sizes, opts.dtype(itype));
  624. at::Tensor dx1;
  625. if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
  626. at::Tensor dresidual;
  627. if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
  628. auto dgamma0 = torch::empty_like(gamma0);
  629. auto dbeta0 = torch::empty_like(gamma0);
  630. at::Tensor dgamma1, dbeta1;
  631. if (gamma1_.has_value()) {
  632. dgamma1 = torch::empty_like(gamma0);
  633. dbeta1 = torch::empty_like(gamma0);
  634. }
  635. layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
  636. launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
  637. launch_params.props = at::cuda::getCurrentDeviceProperties();
  638. TORCH_CHECK(dropout_p < 1.f);
  639. launch_params.params.dropout_keep_p = 1.f - dropout_p;
  640. launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
  641. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  642. const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
  643. auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
  644. launcher(launch_params, true);
  645. auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
  646. auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
  647. at::Tensor dgamma1_part, dbeta1_part;
  648. if (gamma1_.has_value()) {
  649. dgamma1_part = torch::zeros_like(dgamma0_part);
  650. dbeta1_part = torch::zeros_like(dbeta0_part);
  651. }
  652. at::Tensor workspace, barrier;
  653. layer_norm::BwdParams &params = launch_params.params;
  654. params.rows = rows;
  655. params.cols = cols;
  656. params.x = x.data_ptr();
  657. params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
  658. params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;
  659. params.mu = mu.data_ptr();
  660. params.rs = rsigma.data_ptr();
  661. params.gamma = gamma0.data_ptr();
  662. params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
  663. params.dz = dz0.data_ptr();
  664. params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
  665. params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
  666. params.dx0 = dx0.data_ptr();
  667. params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
  668. params.dbeta = dbeta0.data_ptr();
  669. params.dgamma = dgamma0.data_ptr();
  670. params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
  671. params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
  672. params.dbeta_part = dbeta0_part.data_ptr();
  673. params.dgamma_part = dgamma0_part.data_ptr();
  674. params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
  675. params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
  676. params.dropout_scale = 1.f / (1.f - dropout_p);
  677. params.inverse_cols = 1.f / float(params.cols);
  678. params.is_rms_norm = is_rms_norm;
  679. if( launch_params.barrier_size > 0 ) {
  680. // TODO Any way to avoid this?
  681. barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
  682. workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
  683. params.workspace = workspace.data_ptr();
  684. params.barrier = barrier.data_ptr<int>();
  685. }
  686. launcher(launch_params, false);
  687. std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
  688. return result;
  689. }
  690. ////////////////////////////////////////////////////////////////////////////////////////////////////
  691. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  692. m.doc() = "CUDA DropoutAddLayerNorm";
  693. m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
  694. py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
  695. py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
  696. py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
  697. py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
  698. m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
  699. py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
  700. py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
  701. py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
  702. py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
  703. m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd, "Run Dropout + Add + LayerNorm parallel residual forward kernel",
  704. py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
  705. py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
  706. py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
  707. m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd, "Run Dropout + Add + LayerNorm parallel residual backward kernel",
  708. py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
  709. py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
  710. py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
  711. }