# layernorm.py

  1. """Custom normalization layers."""
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling._custom_op import CustomOp


class LayerNorm(nn.LayerNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply layer normalization, adding the residual first if one is given."""
        if residual is not None:
            x = x + residual
            residual = x

        x = super().forward(x)
        if residual is None:
            return x
        else:
            return x, residual


class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        # Compute in float32 for numerical stability, then cast back.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite import _custom_ops as ops

        if residual is not None:
            # The fused kernel writes the normalized output into x and the
            # updated residual into residual, in place.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            # Lazily wrap forward_static with torch.compile on first use and
            # cache the compiled callable on the instance.
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)
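

# A minimal usage sketch (not part of the original module): it exercises the
# PyTorch-native paths on random data. The shapes and dtype below are
# illustrative assumptions, and running it still requires the aphrodite
# package imported at the top of the file.
if __name__ == "__main__":
    hidden_size = 8
    x = torch.randn(2, 4, hidden_size, dtype=torch.float16)
    residual = torch.randn_like(x)

    # RMSNorm with a residual: returns (normalized x, updated residual).
    rms_norm = RMSNorm(hidden_size)
    out, new_residual = rms_norm.forward_native(x, residual)
    print(out.shape, new_residual.shape)

    # GemmaRMSNorm without a residual: returns only the normalized tensor.
    gemma_norm = GemmaRMSNorm(hidden_size)
    print(gemma_norm.forward_native(x).shape)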