rms_norm.py

# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)
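
# Usage sketch (not part of the original file): rms_norm drives the fused LayerNorm kernel with
# no residual, no dropout and no bias; the trailing True argument selects the RMSNorm path of
# the shared DropoutAddLayerNormFn. Shapes and dtypes below are illustrative and assume a CUDA
# build of the flash_attn layer_norm extension.
#
#   x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
#   w = torch.ones(1024, device='cuda', dtype=torch.float16)
#   y = rms_norm(x, w, 1e-5)
#   # Reference semantics (standard RMSNorm definition):
#   # y ~= (x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-5)).to(x.dtype) * w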


def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                         layerscale=None, prenorm=False, residual_in_fp32=False,
                         return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32,
        prenorm, True, return_dropout_mask
    )
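
# Usage sketch (not part of the original file; argument values are illustrative). The fused op
# computes, roughly, rms_norm(dropout(x0) + residual), with optional per-row (rowscale) and
# per-column (layerscale) scaling of x0. With prenorm=True the un-normalized sum is returned as
# well, so the caller can carry the residual stream forward:
#
#   out = dropout_add_rms_norm(x0, residual, weight, None, 0.1, 1e-5)
#   out, new_residual = dropout_add_rms_norm(x0, residual, weight, None, 0.1, 1e-5,
#                                            prenorm=True)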


def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )
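
# Note (interpretation, not from the original file): the *_subset variant additionally takes
# x0_subset / out_subset index tensors so that only a subset of rows takes part in the
# dropout-add and only a subset of rows (out_numrows of them) is written to the normalized
# output, e.g. when only the masked-token positions of a batch are needed downstream.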


def dropout_add_rms_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1,
    dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32,
        prenorm, True, return_dropout_mask
    )
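
# Usage sketch (not part of the original file): the parallel-residual variant normalizes the sum
# of two input streams and a shared residual with two separate weight/bias sets, as in parallel
# attention/MLP blocks. Roughly, out0 = rms_norm(dropout(x0) + dropout(x1) + residual) using
# weight0, and out1 is the same sum normalized with weight1; prenorm=True additionally returns
# the summed residual.
#
#   out0, out1 = dropout_add_rms_norm_parallel_residual(
#       x0, x1, residual, weight0, None, weight1, None, 0.1, 1e-5
#   )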


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                    self.p if self.training else 0.0, self.epsilon,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
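

# Minimal runnable sketch (not part of the original file). Assumes a CUDA build of the fused
# flash_attn layer_norm extension and an available GPU; shapes, dtypes and dropout probability
# are illustrative only.
if __name__ == '__main__':
    if torch.cuda.is_available():
        hidden_size = 1024
        norm = DropoutAddRMSNorm(hidden_size, prenorm=True, p=0.1,
                                 device='cuda', dtype=torch.float16)
        x0 = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float16)
        residual = torch.randn_like(x0)
        # With prenorm=True the call returns both the normalized output and the updated
        # residual stream (dropout(x0) + residual).
        out, new_residual = norm(x0, residual)
        print(out.shape, new_residual.shape)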