rms_norm.py

# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init
from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )
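

# A minimal pure-PyTorch sketch of the math the fused kernel above is expected
# to compute: RMS normalization scales each row by its reciprocal
# root-mean-square (no mean subtraction), then applies the learned weight.
# `_rms_norm_ref` is illustrative only and is not part of the flash_attn API.
def _rms_norm_ref(x, weight, epsilon):
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + epsilon)
    return ((x.float() * rstd) * weight.float()).to(x.dtype)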


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps the dtype of the residual argument.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # RMS norm: skip the mean subtraction of LayerNorm
        return_dropout_mask,
    )
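

# Note (an interpretation based on the argument names, not documented in this
# file): rowscale applies a per-row multiplier to x0 (e.g. for stochastic
# depth) and layerscale a per-channel multiplier (LayerScale-style), both
# fused into the kernel before the residual add and normalization.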


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps the dtype of the residual argument.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,  # RMS norm: skip the mean subtraction of LayerNorm
        return_dropout_mask,
    )
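

# Note (hedged reading of the parameters, not documented here): x0_subset and
# out_subset let the kernel read only a subset of the rows of x0 and write
# only a subset of the output rows (out_numrows of them), which is useful when
# only some tokens need to be normalized.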


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps the dtype of the residual argument.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # RMS norm: skip the mean subtraction of LayerNorm
        return_dropout_mask,
    )
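

# Note (an interpretation, not stated in this file): the parallel-residual
# variant targets blocks where two branches (e.g. attention and MLP computed
# in parallel, GPT-J style) share one residual stream, so x0 and x1 are each
# normalized with their own weight0/weight1 while being added into the same
# residual.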


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
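

# Usage sketch (shapes and dtype are illustrative assumptions; the fused
# kernel expects CUDA tensors, typically fp16/bf16):
#   norm = RMSNorm(1024, eps=1e-5).to(device="cuda", dtype=torch.float16)
#   y = norm(torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16))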


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,  # dropout only during training
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
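

# Usage sketch for the fused dropout + residual-add + RMS norm path (hedged;
# names and shapes are illustrative). With prenorm=True the module is expected
# to also return the updated residual stream for the next block:
#   fused = DropoutAddRMSNorm(1024, prenorm=True, p=0.1, residual_in_fp32=True)
#   fused = fused.to(device="cuda", dtype=torch.float16)
#   hidden, residual = fused(hidden, residual)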