# layer_norm.py
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                                    residual_in_fp32):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + x1 was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    colscale = colscale.view(-1) if colscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
        has_residual
    )
    # dx1mat is None if not has_residual
    if colscale is None:
        return dx0mat, dx1mat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dx1mat, dgamma, dbeta, dcolscale


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
                prenorm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        # One gradient per forward input: x0, x1, gamma, beta, rowscale, colscale,
        # dropout_p, epsilon, residual_in_fp32, prenorm, return_dmask.
        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None


def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        return_dropout_mask
    )
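

# The sketch below is not part of the original file. It is an unfused, pure-PyTorch reference for
# the semantics of dropout_add_layer_norm, i.e. roughly z = LayerNorm(dropout(x0 * layerscale)
# * rowscale + x1), kept here only as documentation. The name `dropout_add_layer_norm_ref`, the
# order in which rowscale/layerscale are applied, and the dtype handling are assumptions; the
# elementwise scalings commute so the forward value is unaffected, but the CUDA kernel is the
# authoritative definition. It also uses PyTorch's own RNG, so with dropout_p > 0 it matches the
# fused op only in distribution, not element-for-element.
def dropout_add_layer_norm_ref(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                               layerscale=None, prenorm=False, residual_in_fp32=False):
    import torch.nn.functional as F
    x = x0 * layerscale if layerscale is not None else x0
    if rowscale is not None:
        # rowscale is expected to have one entry per row, i.e. shape x0.shape[:-1]
        x = x * rowscale[..., None]
    x = F.dropout(x, p=dropout_p, training=True)
    if x1 is not None:
        x = x + x1          # residual dtype follows x1.dtype
    elif residual_in_fp32:
        x = x.float()       # keep the residual in fp32 when there is no x1
    z = F.layer_norm(x.to(weight.dtype), (weight.numel(),), weight, bias, epsilon)
    return (z, x) if prenorm else z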


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, x1=None):
        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
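

# Minimal usage sketch (not part of the original file). It assumes a CUDA device and that the
# `dropout_layer_norm` extension has been built; the shapes, dtype, and dropout probability are
# arbitrary illustrative choices.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, seqlen, hidden = 2, 128, 768
    module = DropoutAddLayerNorm(hidden, prenorm=True, p=0.1,
                                 device='cuda', dtype=torch.float16)
    x0 = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16,
                     requires_grad=True)
    x1 = torch.randn_like(x0, requires_grad=True)
    # With prenorm=True the module returns both the normalized output and the updated residual.
    out, residual = module(x0, x1)
    (out.float().sum() + residual.float().sum()).backward()
    print(out.shape, residual.dtype, x0.grad.shape)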