12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- """Custom normalization layers."""
- from typing import Optional, Tuple, Union
- import torch
- import torch.nn as nn
- from aphrodite._C import ops
class LayerNorm(nn.LayerNorm):
    """Layer normalization with an optional residual add before normalizing.

    Thin wrapper over ``torch.nn.LayerNorm``. It differs in two ways: the
    default epsilon is 1e-6 (PyTorch's default is 1e-5), and ``forward`` can
    first add a residual tensor to the input.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply layer normalization, optionally after adding a residual.

        Args:
            x: Input tensor, normalized over its last dimension (of size
                ``hidden_size``) as ``nn.LayerNorm`` does.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor when ``residual`` is None. Otherwise, the
            tuple ``(normalized, x + residual)``. The second element is the
            pre-normalization sum, handed back so the caller can reuse it as
            the next residual.
        """
        if residual is None:
            return super().forward(x)
        summed = x + residual
        return super().forward(summed), summed
class RMSNorm(nn.Module):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps), where w is the learned weight
    and the mean is taken over the last dimension.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Learned scale of length hidden_size, initialized to all ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def _forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Pure-PyTorch reference for the computation done by forward().

        The mean of squares and the rsqrt scaling are computed in float32. The
        normalized values are then cast back to the input dtype before the
        weight is applied. This method builds new tensors and does not write
        into ``x`` or ``residual``. By contrast, forward()'s residual path
        returns the very tensors it was given, so it presumably updates them
        in place.

        Args:
            x: Input tensor; normalized over its last dimension.
            residual: Optional tensor added to ``x`` (in float32) before
                normalizing.

        Returns:
            The normalized tensor when ``residual`` is None. Otherwise, the
            tuple ``(normalized, x + residual)``, with the sum cast back to the
            input dtype of ``x``.
        """
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        if residual is not None:
            hidden = hidden + residual.to(torch.float32)
            # Hand the pre-norm sum back in the caller's dtype.
            residual = hidden.to(input_dtype)
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + self.variance_epsilon)
        normed = hidden.to(input_dtype) * self.weight
        return normed if residual is None else (normed, residual)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMS normalization using the custom ``ops`` kernels.

        Without ``residual``, the result goes into a freshly allocated tensor
        (``torch.empty_like(x)``) via ``ops.rms_norm``, and that tensor is
        returned.

        With ``residual``, the call is delegated to
        ``ops.fused_add_rms_norm``, and the same ``(x, residual)`` tensor
        objects that were passed in are returned.
        NOTE(review): this implies the kernel writes its results into ``x``
        and ``residual`` in place. Confirm against the kernel before relying
        on the inputs afterwards.
        """
        if residual is None:
            out = torch.empty_like(x)
            ops.rms_norm(
                out,
                x,
                self.weight.data,
                self.variance_epsilon,
            )
            return out
        ops.fused_add_rms_norm(
            x,
            residual,
            self.weight.data,
            self.variance_epsilon,
        )
        return x, residual
|