layer_norm.py

# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                           out_subset, dropout_p, epsilon, rowscale_const,
                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
    epsilon, residual_in_fp32=False, is_rms_norm=False
):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
        None, residual_in_fp32, is_rms_norm
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
    dropout_p, has_x1, has_residual, is_rms_norm=False
):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
        dropout_p, has_x1, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
            ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
                None, None, None, None, None)


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p,
                epsilon, rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
                None, None, None, None, None, None, None, None)


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        residual = residual.contiguous() if residual is not None else None
        gamma0 = gamma0.contiguous()
        beta0 = beta0.contiguous() if beta0 is not None else None
        gamma1 = gamma1.contiguous() if gamma1 is not None else None
        beta1 = beta1.contiguous() if beta1 is not None else None
        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return ((*z, dmask0, dmask1) if not prenorm
                    else (*z, xmat.view(x0.shape), dmask0, dmask1))

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = dz0.contiguous()  # this happens!
        dz1 = dz1.contiguous() if dz1 is not None else None
        dx = args[0].contiguous() if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
            has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


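# Usage sketch (illustrative, not part of the original file): `layer_norm` normalizes over the
# last dimension with the fused CUDA kernel, like torch.nn.functional.layer_norm with affine
# parameters. The shapes and dtypes below are assumptions for the example only.
#
#     x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#     weight = torch.ones(1024, device='cuda', dtype=torch.float16)
#     bias = torch.zeros(1024, device='cuda', dtype=torch.float16)
#     out = layer_norm(x, weight, bias, 1e-5)  # same shape as x

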
def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32,
        prenorm, False, return_dropout_mask
    )


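# Usage sketch (illustrative, semantics hedged): the fused op roughly computes
#     x = dropout(x0, p) + residual   (optionally scaling x0 per-row/per-column via
#                                      rowscale / layerscale first)
#     z = LayerNorm(x) with affine parameters (weight, bias),
# and with prenorm=True it also returns x so the caller can carry the residual stream forward.
#
#     z = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5)
#     z, x = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5,
#                                   prenorm=True, residual_in_fp32=True)

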
def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
    )


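# Usage sketch (illustrative only; the index convention expected by the dropout_layer_norm kernel
# for x0_subset / out_subset is an assumption to verify against the extension, and the variable
# names below are placeholders): x0_subset selects which rows of the residual stream receive x0,
# out_subset selects which rows are normalized and returned, rowscale_const rescales x0, and
# out_numrows gives the number of rows in the returned output.
#
#     z, x = dropout_add_layer_norm_subset(
#         x0, residual, weight, bias, 0.1, 1e-5,
#         x0_subset=x0_subset, out_subset=out_subset,
#         rowscale_const=1.0, out_numrows=num_out_rows, prenorm=True,
#     )

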
def dropout_add_layer_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
    residual_in_fp32=False, return_dropout_mask=False
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32,
        prenorm, False, return_dropout_mask
    )


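# Usage sketch (illustrative; variable names are placeholders): this variant targets parallel
# blocks where two branches (e.g. attention output x0 and MLP output x1) are added to the same
# residual stream, which is then normalized twice with separate (weight0, bias0) and
# (weight1, bias1) to produce the inputs of the next pair of branches.
#
#     z0, z1, x = dropout_add_layer_norm_parallel_residual(
#         attn_out, mlp_out, residual, ln0_weight, ln0_bias, ln1_weight, ln1_bias,
#         0.1, 1e-5, prenorm=True, residual_in_fp32=True,
#     )

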
class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
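

# Usage sketch (illustrative assumptions only): the module fuses the common Transformer pattern
# "x = dropout(hidden) + residual; out = LayerNorm(x)" into one kernel call, and disables dropout
# automatically in eval mode via self.training.
#
#     norm = DropoutAddLayerNorm(1024, prenorm=False, p=0.1, eps=1e-5,
#                                device='cuda', dtype=torch.float16)
#     out = norm(hidden_states, residual)  # hidden_states, residual: (batch, seqlen, 1024)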