# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
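
# Usage sketch for maybe_align (an illustrative example, not part of the original file;
# whether a tensor is already 16-byte aligned depends on the allocator, so the clone()
# branch may or may not be taken):
#
#     x = torch.randn(8, 256, device="cuda", dtype=torch.float16)
#     x = maybe_align(x, alignment_in_bytes=16)  # returns x unchanged if data_ptr() is aligned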


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
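
# Minimal usage sketch for layer_norm (illustrative shapes, not from the original file;
# requires the compiled dropout_layer_norm CUDA extension and CUDA tensors whose last dim
# matches weight):
#
#     hidden = 1024
#     x = torch.randn(4, 512, hidden, device="cuda", dtype=torch.float16)
#     w = torch.ones(hidden, device="cuda", dtype=torch.float16)
#     b = torch.zeros(hidden, device="cuda", dtype=torch.float16)
#     out = layer_norm(x, w, b, 1e-5)  # plain LayerNorm: dropout_p=0.0, no residual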


def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
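
# Usage sketch for dropout_add_layer_norm (illustrative shapes, not from the original file).
# The fused op computes LayerNorm(dropout(x0) + residual); with prenorm=True it also returns
# the pre-normalization sum so the next block can use it as its residual stream:
#
#     hidden = 1024
#     x0 = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
#     residual = torch.randn_like(x0)
#     w = torch.ones(hidden, device="cuda", dtype=torch.float16)
#     b = torch.zeros(hidden, device="cuda", dtype=torch.float16)
#     out = dropout_add_layer_norm(x0, residual, w, b, 0.1, 1e-5)
#     # prenorm usage; residual_in_fp32 only matters when residual is None (see docstring):
#     out, pre_norm = dropout_add_layer_norm(
#         x0, None, w, b, 0.1, 1e-5, prenorm=True, residual_in_fp32=True
#     )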


def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
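
# Usage sketch for dropout_add_layer_norm_subset (illustrative, not from the original file).
# The subset arguments let the fused kernel read and/or write only a subset of rows (e.g. for
# token-dropping schemes), with rowscale_const rescaling the kept rows and out_numrows giving
# the number of output rows. The exact index convention of x0_subset/out_subset is defined by
# the CUDA kernel, so this sketch leaves them at None, where the call should reduce to
# dropout_add_layer_norm:
#
#     hidden = 1024
#     x0 = torch.randn(128, hidden, device="cuda", dtype=torch.float16)
#     w = torch.ones(hidden, device="cuda", dtype=torch.float16)
#     b = torch.zeros(hidden, device="cuda", dtype=torch.float16)
#     out = dropout_add_layer_norm_subset(x0, None, w, b, 0.1, 1e-5)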


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
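
# Usage sketch for the parallel-residual variant (illustrative shapes, not from the original
# file). x0 and x1 are two branch outputs (e.g. attention and MLP in a parallel block) that
# are dropped out and added to the shared residual; the same sum is then normalized twice,
# once with (weight0, bias0) and once with (weight1, bias1), giving (out0, out1):
#
#     hidden = 1024
#     x0 = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
#     x1 = torch.randn_like(x0)
#     residual = torch.randn_like(x0)
#     w0 = torch.ones(hidden, device="cuda", dtype=torch.float16)
#     b0 = torch.zeros(hidden, device="cuda", dtype=torch.float16)
#     w1 = torch.ones(hidden, device="cuda", dtype=torch.float16)
#     b1 = torch.zeros(hidden, device="cuda", dtype=torch.float16)
#     out0, out1 = dropout_add_layer_norm_parallel_residual(
#         x0, x1, residual, w0, b0, w1, b1, 0.1, 1e-5
#     )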


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
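
# Usage sketch for the DropoutAddLayerNorm module (illustrative shapes and hyperparameters,
# not from the original file). In training mode dropout with probability p is applied to x0
# before the residual add and LayerNorm; in eval mode the dropout probability passed down
# is 0.0:
#
#     layer = DropoutAddLayerNorm(
#         1024, prenorm=False, p=0.1, eps=1e-5, device="cuda", dtype=torch.float16
#     )
#     x0 = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
#     residual = torch.randn_like(x0)
#     out = layer(x0, residual)  # LayerNorm(dropout(x0) + residual)
#     out = layer(x0)  # residual defaults to None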