123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- """Custom normalization layers."""
- from typing import Optional, Tuple, Union
- import torch
- import torch.nn as nn
- from aphrodite.modeling._custom_op import CustomOp
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose forward pass can also fold in a residual.

    The constructor forwards ``hidden_size`` and ``eps`` (default 1e-6) to
    ``nn.LayerNorm``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Layer-normalize ``x``, adding ``residual`` first when given.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise the
            pair ``(normalized, x + residual)`` so the caller can carry the
            pre-normalization sum forward.
        """
        # Guard clause: the plain path is just the parent's layer norm.
        if residual is None:
            return super().forward(x)

        summed = x + residual
        return super().forward(summed), summed
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Learned per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Pure-PyTorch RMS norm, with an optional residual add beforehand.

        The statistics are computed in float32; the normalized tensor is
        cast back to the input dtype before being multiplied by the weight.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise
            ``(normalized, x + residual)`` with the sum cast back to the
            input dtype.
        """
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        if residual is not None:
            hidden = hidden + residual.to(torch.float32)
            residual = hidden.to(input_dtype)

        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.variance_epsilon)
        # Cast back to the input dtype *before* applying the weight.
        normed = (hidden * inv_rms).to(input_dtype) * self.weight

        if residual is not None:
            return normed, residual
        return normed

    def _dispatch_kernels(
        self,
        ops,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the RMS-norm kernels exposed by the backend module ``ops``.

        With a residual, ``ops.fused_add_rms_norm`` is called and ``x`` and
        ``residual`` are returned as-is afterwards (the op's return value is
        not used, so it presumably updates both tensors in place — confirm
        against the kernel). Without one, ``ops.rms_norm`` is handed a fresh
        ``empty_like`` buffer as its first argument and that buffer is
        returned.
        """
        scale = self.weight.data
        eps = self.variance_epsilon
        if residual is None:
            out = torch.empty_like(x)
            ops.rms_norm(out, x, scale, eps)
            return out

        ops.fused_add_rms_norm(x, residual, scale, eps)
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """RMS norm using the ops from ``aphrodite._custom_ops``."""
        # Imported lazily, only when this code path is actually taken.
        from aphrodite import _custom_ops as ops

        return self._dispatch_kernels(ops, x, residual)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """RMS norm using ``ipex_ops`` from ``aphrodite._ipex_ops``."""
        # Imported lazily, only when this code path is actually taken.
        from aphrodite._ipex_ops import ipex_ops as ops

        return self._dispatch_kernels(ops, x, residual)

    def extra_repr(self) -> str:
        """Return ``hidden_size=..., eps=...`` for the module repr."""
        hidden_size = self.weight.data.size(0)
        return f"hidden_size={hidden_size}, eps={self.variance_epsilon}"
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from a plain RMS norm that does ``x.to(orig_dtype) * w``:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialised: the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Stateless PyTorch implementation of the Gemma RMS norm.

        When ``residual`` is given it is added to ``x`` first and the sum is
        returned as the second element of the result. The normalization and
        the ``(1 + weight)`` scaling are done in float32; the result is cast
        back to the dtype ``x`` had on entry.
        """
        input_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = (hidden * (1.0 + weight.float())).to(input_dtype)

        if residual is None:
            return scaled
        return scaled, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply ``forward_static`` with this module's weight and epsilon."""
        scale = self.weight.data
        eps = self.variance_epsilon
        return self.forward_static(scale, eps, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the native path, compiling ``forward_static`` on first use.

        While an outer ``torch.compile`` trace is in progress the native
        path is used directly. Otherwise ``forward_static`` is wrapped with
        ``torch.compile`` once per instance, tracked by ``_is_compiled``.
        """
        # Short-circuit keeps the original order: the compile state is only
        # consulted when no outer compilation is in progress.
        if (not torch.compiler.is_compiling()
                and not getattr(self, "_is_compiled", False)):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)
|