layernorm.py

  1. """Custom normalization layers."""
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite._C import ops
  6. class RMSNorm(nn.Module):
  7. """Root mean square normalization.
  8. Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
  9. Refer to https://arxiv.org/abs/1910.07467
  10. """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def _forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        # Compute in float32 for numerical stability, then cast back.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            # The updated residual (x + residual) is handed back to the caller.
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # The fused kernel operates in place on x and residual, so both
            # are returned as-is after the call.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out
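

# A minimal usage sketch (not part of the original module). It exercises the
# PyTorch-native _forward path, which needs no CUDA kernels, although importing
# this file still requires the aphrodite._C extension to be installed. The
# shapes and hidden size below are arbitrary illustrative values.
if __name__ == "__main__":
    norm = RMSNorm(hidden_size=8)
    x = torch.randn(2, 4, 8)
    residual = torch.randn(2, 4, 8)

    # Without a residual, only the normalized tensor is returned.
    out = norm._forward(x)
    assert out.shape == x.shape

    # With a residual, the normalized tensor and the updated residual
    # (x + residual) are returned as a pair.
    out_fused, new_residual = norm._forward(x, residual)
    assert torch.allclose(new_residual, x + residual)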