layernorm.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """Custom normalization layers."""
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite._C import ops
  6. class LayerNorm(nn.LayerNorm):
  7. def __init__(
  8. self,
  9. hidden_size: int,
  10. eps: float = 1e-6,
  11. ) -> None:
  12. super().__init__(hidden_size, eps=eps)
  13. def forward(
  14. self,
  15. x: torch.Tensor,
  16. residual: Optional[torch.Tensor] = None,
  17. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  18. """normalization."""
  19. if residual is not None:
  20. x = x + residual
  21. residual = x
  22. x = super().forward(x)
  23. if residual is None:
  24. return x
  25. else:
  26. return x, residual
  27. class RMSNorm(nn.Module):
  28. """Root mean square normalization.
  29. Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
  30. Refer to https://arxiv.org/abs/1910.07467
  31. """
  32. def __init__(
  33. self,
  34. hidden_size: int,
  35. eps: float = 1e-6,
  36. ) -> None:
  37. super().__init__()
  38. self.weight = nn.Parameter(torch.ones(hidden_size))
  39. self.variance_epsilon = eps
  40. def _forward(
  41. self,
  42. x: torch.Tensor,
  43. residual: Optional[torch.Tensor] = None,
  44. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  45. """PyTorch-native implementation equivalent to forward()."""
  46. orig_dtype = x.dtype
  47. x = x.to(torch.float32)
  48. if residual is not None:
  49. x = x + residual.to(torch.float32)
  50. residual = x.to(orig_dtype)
  51. variance = x.pow(2).mean(dim=-1, keepdim=True)
  52. x = x * torch.rsqrt(variance + self.variance_epsilon)
  53. x = x.to(orig_dtype) * self.weight
  54. if residual is None:
  55. return x
  56. else:
  57. return x, residual
  58. def forward(
  59. self,
  60. x: torch.Tensor,
  61. residual: Optional[torch.Tensor] = None,
  62. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  63. if residual is not None:
  64. ops.fused_add_rms_norm(
  65. x,
  66. residual,
  67. self.weight.data,
  68. self.variance_epsilon,
  69. )
  70. return x, residual
  71. out = torch.empty_like(x)
  72. ops.rms_norm(
  73. out,
  74. x,
  75. self.weight.data,
  76. self.variance_epsilon,
  77. )
  78. return out