# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    """Plain RMSNorm over the last dimension (no dropout, no residual, no bias),
    implemented with the fused dropout-add-layernorm kernel in RMSNorm mode.
    """
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )
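
# Usage sketch (not part of the module API): the shapes, dtype, and device below
# are assumptions for illustration; the fused kernel expects CUDA tensors and a
# hidden size it supports.
#
#   x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
#   weight = torch.ones(1024, device="cuda", dtype=torch.float16)
#   out = rms_norm(x, weight, 1e-5)  # same shape and dtype as x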


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm
        return_dropout_mask,
    )
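
# Usage sketch (assumptions: CUDA tensors, fp16, hidden size 1024). The fused op
# computes rmsnorm(dropout(x0) + residual); with prenorm=True it also returns the
# pre-normalization sum so the caller can carry the residual stream forward.
#
#   x0 = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
#   residual = torch.randn_like(x0)
#   weight = torch.ones(1024, device="cuda", dtype=torch.float16)
#   out, new_residual = dropout_add_rms_norm(
#       x0, residual, weight, None, 0.1, 1e-5, prenorm=True
#   )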


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm
        return_dropout_mask,
    )
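
# Usage sketch (assumptions: CUDA tensors, fp16, hidden size 1024). With
# x0_subset/out_subset left as None this reduces to the non-subset path; the
# subset arguments restrict which rows get the dropout+add and/or the norm output.
#
#   x0 = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
#   weight = torch.ones(1024, device="cuda", dtype=torch.float16)
#   out = dropout_add_rms_norm_subset(x0, None, weight, None, 0.1, 1e-5)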


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm
        return_dropout_mask,
    )
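
# Usage sketch (assumptions: CUDA tensors, fp16, hidden size 1024). This variant
# targets parallel-block architectures: two branch outputs x0 and x1 are dropped
# out and added to the shared residual stream, and the sum is normalized with
# each (weight, bias) pair, giving one normalized output per branch.
#
#   x0 = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
#   x1 = torch.randn_like(x0)
#   w0 = torch.ones(1024, device="cuda", dtype=torch.float16)
#   w1 = torch.ones(1024, device="cuda", dtype=torch.float16)
#   out0, out1 = dropout_add_rms_norm_parallel_residual(
#       x0, x1, None, w0, None, w1, None, 0.1, 1e-5
#   )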


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
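
# Usage sketch (assumptions: CUDA device, fp16 activations, hidden size 1024):
#
#   norm = RMSNorm(1024, device="cuda", dtype=torch.float16)
#   x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
#   y = norm(x)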


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
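
# Usage sketch (assumptions: CUDA device, fp16 activations, hidden size 1024).
# With prenorm=True the module returns both the normalized output and the updated
# residual stream, matching the usual pre-norm transformer-block pattern; dropout
# is only applied in training mode.
#
#   block_norm = DropoutAddRMSNorm(
#       1024, prenorm=True, p=0.1, residual_in_fp32=True,
#       device="cuda", dtype=torch.float16,
#   )
#   x0 = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
#   out, residual = block_norm(x0)             # first block: residual starts as None
#   out, residual = block_norm(out, residual)  # later blocks reuse the residual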