layer_norm.py

# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

# from apex._autocast_utils import _cast_if_autocast_enabled
import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
                                    residual_in_fp32):
    """Assume that arguments are contiguous."""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0, x1 is None, and residual_dtype == input_dtype
    # (i.e. the kernel did not need to materialize a separate residual tensor)
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
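

# A minimal unfused reference sketch (not part of the original file or the fused extension):
# it spells out the computation the CUDA kernel above performs, roughly
# LayerNorm(dropout(x0) * rowscale + x1), and also returns the pre-LayerNorm sum that
# prenorm callers reuse as the next residual. The function name is ours; it returns only
# (z, x) rather than (z, x, dmask, mu, rsigma), and dropout RNG / rounding will not match
# the fused kernel.
def _dropout_add_layer_norm_reference(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
                                      residual_in_fp32):
    if rowscale is not None:
        # one scale per row; broadcasting over the hidden dimension
        x0 = x0 * rowscale.view(*x0.shape[:-1], 1)
    x0 = torch.nn.functional.dropout(x0, p=dropout_p)
    # Residual dtype follows the docstring of dropout_add_layer_norm below:
    # x1.dtype if a residual is given, else fp32 when residual_in_fp32 is set.
    if x1 is not None:
        residual_dtype = x1.dtype
    elif residual_in_fp32:
        residual_dtype = torch.float32
    else:
        residual_dtype = x0.dtype
    x = x0.to(residual_dtype)
    if x1 is not None:
        x = x + x1.to(residual_dtype)
    z = torch.nn.functional.layer_norm(x.to(gamma.dtype), (gamma.numel(),), gamma, beta, epsilon)
    return z, x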


def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
                                     has_residual):
    """Assume that arguments are contiguous."""
    # dmask is None if dropout_p == 0.0
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    rowscale = rowscale.view(-1) if rowscale is not None else None
    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
    )
    # dx1mat is None if not has_residual
    return dx0mat, dx1mat, dgamma, dbeta


def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
                                             dropout_p, has_residual):
    """Assume that arguments are contiguous."""
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape)
    rowscale = rowscale.view(-1) if rowscale is not None else None
    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
        dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
    )
    return dx0mat, dx1mat, dgamma, dbeta


class DropoutAddLayerNormFN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
                return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return zmat.view(x0.shape)
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return zmat.view(x0.shape), dmask

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens! (autograd can hand us a non-contiguous grad)
        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
            dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
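

# This Function is normally reached through dropout_add_layer_norm below; calling .apply
# directly also works and gives gradients for x0, x1, gamma, and beta. A minimal sketch
# (an assumption, not from the original file; requires the compiled extension and CUDA):
#
#     hidden = 1024
#     x0 = torch.randn(2, 8, hidden, device='cuda', dtype=torch.float16, requires_grad=True)
#     res = torch.randn_like(x0).requires_grad_()
#     w = torch.ones(hidden, device='cuda', dtype=torch.float16, requires_grad=True)
#     b = torch.zeros(hidden, device='cuda', dtype=torch.float16, requires_grad=True)
#     out = DropoutAddLayerNormFN.apply(x0, res, w, b, None, 0.1, 1e-5, False)
#     out.sum().backward()  # populates x0.grad, res.grad, w.grad, b.grad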


class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
                return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return zmat.view(x0.shape), xmat.view(x0.shape)
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return zmat.view(x0.shape), xmat.view(x0.shape), dmask

    @staticmethod
    def backward(ctx, dz, dx, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens! (autograd can hand us non-contiguous grads)
        dx = dx.contiguous()  # this happens!
        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
            dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        return dx0, dx1, dgamma, dbeta, None, None, None, None, None


def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                           prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
            return_dropout_mask)
    if not prenorm:
        return DropoutAddLayerNormFN.apply(*args)
    else:
        return DropoutAddLayerNormPrenormFN.apply(*args)
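

# Example usage (a sketch, not from the original file; assumes the compiled
# dropout_layer_norm extension is installed and inputs are fp16 CUDA tensors):
#
#     hidden = 1024
#     x0 = torch.randn(8, 512, hidden, device='cuda', dtype=torch.float16)
#     residual = torch.randn_like(x0)
#     weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
#     bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)
#     # Post-norm: returns LayerNorm(dropout(x0) + residual)
#     out = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5)
#     # Pre-norm: additionally returns the pre-LayerNorm sum for the next residual branch
#     out, pre_norm = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5,
#                                            prenorm=True)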


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, x1=None):
        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm,
                                      residual_in_fp32=self.residual_in_fp32)
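

# Example usage of the module (a sketch; shapes, device, and dtype here are assumptions):
#
#     layer = DropoutAddLayerNorm(1024, prenorm=False, p=0.1, eps=1e-5,
#                                 device='cuda', dtype=torch.float16)
#     x0 = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#     residual = torch.randn_like(x0)
#     y = layer(x0)            # LayerNorm(dropout(x0)); dropout is active only in training mode
#     y = layer(x0, residual)  # LayerNorm(dropout(x0) + residual)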