layernorm.py

  1. """Custom normalization layers."""
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling._custom_op import CustomOp
  6. class LayerNorm(nn.LayerNorm):
  7. def __init__(
  8. self,
  9. hidden_size: int,
  10. eps: float = 1e-6,
  11. ) -> None:
  12. super().__init__(hidden_size, eps=eps)
  13. def forward(
  14. self,
  15. x: torch.Tensor,
  16. residual: Optional[torch.Tensor] = None,
  17. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  18. """normalization."""
  19. if residual is not None:
  20. x = x + residual
  21. residual = x
  22. x = super().forward(x)
  23. if residual is None:
  24. return x
  25. else:
  26. return x, residual
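
# NOTE: The helper below is an illustrative sketch added for this write-up; it
# is not part of the original file. The function name, hidden size, and tensor
# shapes are arbitrary. It shows how the optional ``residual`` argument fuses
# the residual addition with normalization and returns the pre-norm sum.
def _layer_norm_usage_example() -> None:
    norm = LayerNorm(hidden_size=8)
    x = torch.randn(2, 4, 8)
    residual = torch.randn(2, 4, 8)
    # out is layer norm applied to (x + residual); new_residual is that sum.
    out, new_residual = norm(x, residual)
    assert torch.allclose(new_residual, x + residual)
    assert out.shape == x.shape
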
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite import _custom_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        from aphrodite._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
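
# NOTE: Illustrative check added for this write-up; not part of the original
# file. The helper name, shapes, and tolerance are arbitrary. It verifies that
# forward_native matches the formula documented in the class docstring,
# w * x / sqrt(E[x^2] + eps), for a float32 input.
def _rms_norm_reference_check() -> None:
    rms = RMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 4, 8)
    # Reference computation written directly from the docstring formula.
    ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    ref = ref * rms.weight
    assert torch.allclose(rms.forward_native(x), ref, atol=1e-5)
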
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + self.weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # TODO: Implement an optimized kernel for GemmaRMSNorm.
        return self.forward_native(x, residual)
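
# NOTE: Illustrative check added for this write-up; not part of the original
# file. The helper name and shapes are arbitrary. Because the weight is
# zero-initialized and applied as (1 + w), a freshly constructed GemmaRMSNorm
# reduces to plain RMS normalization with no learned scaling.
def _gemma_rms_norm_zero_weight_check() -> None:
    gemma_norm = GemmaRMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 4, 8)
    # With w == 0, the output is just x / sqrt(E[x^2] + eps).
    ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    assert torch.allclose(gemma_norm.forward_native(x), ref, atol=1e-5)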