# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
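

# Illustrative reference (not part of the original file): the fused kernel above computes,
# roughly, norm(dropout(x0 * rowscale * colscale) + residual), with the statistics computed
# in fp32. The helper name below and the assumption of already-flattened (rows, hidden)
# inputs are mine; this is a sketch of the math, not a drop-in replacement for the CUDA path.
def _reference_dropout_add_layer_norm(x0, residual, gamma, beta, dropout_p, epsilon,
                                      rowscale=None, colscale=None, is_rms_norm=False):
    x = x0.float()
    if colscale is not None:
        x = x * colscale.float()                # per-column (LayerScale-style) scaling
    if rowscale is not None:
        x = x * rowscale.float().unsqueeze(-1)  # per-row scaling
    x = torch.nn.functional.dropout(x, dropout_p)
    residual_out = x + residual.float() if residual is not None else x
    if not is_rms_norm:
        z = torch.nn.functional.layer_norm(residual_out, (residual_out.shape[-1],), gamma.float(),
                                           beta.float() if beta is not None else None, epsilon)
    else:
        # RMSNorm: scale by the root-mean-square instead of subtracting the mean
        inv_rms = residual_out.pow(2).mean(dim=-1, keepdim=True).add(epsilon).rsqrt()
        z = residual_out * inv_rms * gamma.float()
        if beta is not None:
            z = z + beta.float()
    return z.to(x0.dtype), residual_out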


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                           out_subset, dropout_p, epsilon, rowscale_const,
                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
            ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
                None, None, None, None, None)


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p,
                epsilon, rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
                None, None, None, None, None, None, None, None)


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
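

# Sanity-check sketch (assumed shapes/dtypes; requires a CUDA device and the compiled
# dropout_layer_norm extension with support for the chosen hidden size): the fused layer_norm
# above should agree with the unfused PyTorch op up to floating-point tolerance. The helper
# name is illustrative only and not part of the original API.
def _example_layer_norm_check():
    import torch.nn.functional as F
    hidden = 1024
    x = torch.randn(8, 512, hidden, device='cuda', dtype=torch.float16)
    w = torch.randn(hidden, device='cuda', dtype=torch.float16)
    b = torch.randn(hidden, device='cuda', dtype=torch.float16)
    out_fused = layer_norm(x, w, b, 1e-5)
    out_ref = F.layer_norm(x.float(), (hidden,), w.float(), b.float(), 1e-5).to(x.dtype)
    return (out_fused - out_ref).abs().max()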


def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32,
        prenorm, False, return_dropout_mask
    )
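

# Usage sketch for the pre-norm transformer pattern (names and shapes are illustrative):
# with prenorm=True, the call returns both the normalized tensor that feeds the next
# sub-layer and the updated residual stream (dropout(x0) + residual), so the unfused
# add + LayerNorm pair in a block can be replaced by a single fused call.
def _example_prenorm_step(hidden_states, residual, weight, bias, dropout_p, eps):
    normed, residual = dropout_add_layer_norm(
        hidden_states, residual, weight, bias, dropout_p, eps,
        prenorm=True,
        residual_in_fp32=True,  # only takes effect on the first call, where residual is None
    )
    return normed, residual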


def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
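

# Module usage sketch (assumed device/dtype and shapes): DropoutAddLayerNorm fuses
# dropout(x0) + residual followed by LayerNorm. With prenorm=False (the default) only the
# normalized output is returned, and dropout is active only in training mode.
def _example_module_usage():
    ln = DropoutAddLayerNorm(1024, p=0.1, eps=1e-5, residual_in_fp32=True,
                             device='cuda', dtype=torch.float16)
    x0 = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
    residual = torch.randn_like(x0)
    out = ln(x0, residual)  # same shape as x0
    return out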