layer_norm.py

# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

import dropout_layer_norm  # compiled CUDA extension providing the fused dropout + add + LayerNorm kernels


def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
                                    residual_in_fp32):
    """Assume that arguments are contiguous."""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0, x1 is None, and residual_dtype == input_dtype
    # (in that case the saved residual equals x0, so we fall back to x0mat)
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
                                     has_residual):
    """Assume that arguments are contiguous."""
    # dmask is None if dropout_p == 0.0
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    rowscale = rowscale.view(-1) if rowscale is not None else None
    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
    )
    # dx1mat is None if not has_residual
    return dx0mat, dx1mat, dgamma, dbeta


def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
                                             dropout_p, has_residual):
    """Assume that arguments are contiguous."""
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape)
    rowscale = rowscale.view(-1) if rowscale is not None else None
    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
        dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
    )
    return dx0mat, dx1mat, dgamma, dbeta


class DropoutAddLayerNormFN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
                return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return zmat.view(x0.shape)
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return zmat.view(x0.shape), dmask

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
            dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        return dx0, dx1, dgamma, dbeta, None, None, None, None, None


class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
                return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return zmat.view(x0.shape), xmat.view(x0.shape)
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return zmat.view(x0.shape), xmat.view(x0.shape), dmask

    @staticmethod
    def backward(ctx, dz, dx, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = dx.contiguous()  # this happens!
        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
            dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        return dx0, dx1, dgamma, dbeta, None, None, None, None, None


def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                           prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise the residual dtype is x1.dtype.
    See the call-pattern sketch in the comments below this function.
    """
    args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
            return_dropout_mask)
    if not prenorm:
        return DropoutAddLayerNormFN.apply(*args)
    else:
        return DropoutAddLayerNormPrenormFN.apply(*args)
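

# A minimal call-pattern sketch for the functional API above. This is illustrative only:
# the tensor names and shapes are assumptions, and the return conventions are read off the
# autograd Functions in this file rather than quoted from any documentation.
#
#   out = dropout_add_layer_norm(x0, x1, weight, bias, 0.1, 1e-5)
#   out, dmask = dropout_add_layer_norm(x0, x1, weight, bias, 0.1, 1e-5,
#                                       return_dropout_mask=True)  # dmask is the dropout mask
#   out, pre_ln = dropout_add_layer_norm(x0, x1, weight, bias, 0.1, 1e-5, prenorm=True)
#   # with prenorm=True, pre_ln is the dropout(x0) + x1 sum before normalization, which the
#   # next block typically consumes as its residual input.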


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, x1=None):
        # Dropout is only applied in training mode; at eval time dropout_p is forced to 0.
        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
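

# The block below is a hedged usage sketch, not part of the code adapted from apex. It assumes
# a CUDA device with the compiled `dropout_layer_norm` extension installed; the batch size,
# sequence length, and hidden size are made up for illustration.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seqlen, hidden_size = 4, 128, 1024
    device, dtype = 'cuda', torch.float16

    x0 = torch.randn(batch, seqlen, hidden_size, device=device, dtype=dtype)
    residual = torch.randn_like(x0)

    # Module API: dropout + residual add + LayerNorm fused into a single kernel call.
    ln = DropoutAddLayerNorm(hidden_size, p=0.1, device=device, dtype=dtype)
    out = ln(x0, residual)
    print(out.shape, out.dtype)

    # Functional API with prenorm=True: returns both the normalized output and the
    # pre-LayerNorm residual to feed into the next block.
    out, new_residual = dropout_add_layer_norm(x0, residual, ln.weight, ln.bias, 0.1, 1e-5,
                                               prenorm=True)
    print(out.shape, new_residual.shape)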