layernorm.py

  1. """Custom normalization layers."""
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling._custom_op import CustomOp


class LayerNorm(nn.LayerNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Layer normalization with an optional residual add before the norm."""
        if residual is not None:
            x = x + residual
            residual = x
        x = super().forward(x)
        if residual is None:
            return x
        else:
            return x, residual
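

# A minimal usage sketch of the optional `residual` argument: when it is
# passed, the layer adds it to `x` before normalizing and also returns the
# pre-norm sum so the caller can thread it into the next layer. The hidden
# size, tensor shapes, and helper name below are arbitrary choices for
# illustration, not part of this module's API.
def _layernorm_residual_example() -> None:
    norm = LayerNorm(hidden_size=8)
    x = torch.randn(2, 8)
    # First call: no residual yet, only the normalized tensor comes back.
    hidden = norm(x)
    # Later call: pass the running residual; (normalized, new_residual) is
    # returned, with new_residual = hidden + x computed before the norm.
    hidden, residual = norm(hidden, residual=x)
    assert hidden.shape == residual.shape == (2, 8)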


class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite import _custom_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def forward_triton(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite.modeling.layers.ops.layernorm import fast_rms_layernorm

        if residual is not None:
            x = x + residual
            return fast_rms_layernorm(self, x, gemma=False), x
        return fast_rms_layernorm(self, x, gemma=False)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
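

# A minimal sketch checking forward_native against the formula in the class
# docstring, w * x / sqrt(E[x^2] + eps), with the mean taken over the last
# dimension. The shapes, eps value, and tolerance are arbitrary assumptions
# made for the example.
def _rms_norm_reference_example() -> None:
    norm = RMSNorm(hidden_size=16, eps=1e-6)
    x = torch.randn(4, 16)
    out = norm.forward_native(x)
    # Reference computation written directly from the docstring formula.
    ref = norm.weight * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
    assert torch.allclose(out, ref, atol=1e-6)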


class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)

    def forward_triton(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite.modeling.layers.ops.layernorm import fast_rms_layernorm

        if residual is not None:
            x = x + residual
            return fast_rms_layernorm(self, x, gemma=True), x
        return fast_rms_layernorm(self, x, gemma=True)
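

# A minimal sketch of the first documented difference: GemmaRMSNorm scales by
# (1 + w) with a zero-initialized weight, so a freshly constructed layer
# reduces to plain RMS normalization and matches RMSNorm (whose weight starts
# at ones). Comparing the two at init time, and the shapes used, are
# assumptions made for illustration; trained weights and low-precision inputs
# would expose the scaling and dtype-cast differences.
def _gemma_rms_norm_example() -> None:
    gemma_norm = GemmaRMSNorm(hidden_size=16)
    plain_norm = RMSNorm(hidden_size=16)
    x = torch.randn(4, 16)
    gemma_out = gemma_norm.forward_native(x)
    plain_out = plain_norm.forward_native(x)
    # Both compute x / sqrt(E[x^2] + eps) at initialization.
    assert torch.allclose(gemma_out, plain_out, atol=1e-6)


if __name__ == "__main__":
    # Run the illustrative checks defined above.
    _layernorm_residual_example()
    _rms_norm_reference_example()
    _gemma_rms_norm_example()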