# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    # The trailing True selects the RMSNorm path of the fused kernel
    # (no bias, no dropout, no residual, no row/column scaling).
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None;
    otherwise the residual keeps its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # RMSNorm has no bias; register it as None so state_dict layout matches LayerNorm-style modules.
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        # Dropout is only applied in training mode.
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
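

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes a GPU is
    # available and that the fused dropout_add_layer_norm CUDA extension is built
    # (required anyway for the imports above). The shapes, dtype, and dropout
    # probability below are arbitrary example choices, not API requirements
    # beyond weight.shape == (hidden_size,).
    if torch.cuda.is_available():
        batch, seqlen, hidden_size = 2, 128, 1024
        x0 = torch.randn(batch, seqlen, hidden_size, device="cuda", dtype=torch.float16)

        # Plain RMSNorm over the last dimension.
        norm = RMSNorm(hidden_size, device="cuda", dtype=torch.float16)
        out = norm(x0)
        print(out.shape)  # torch.Size([2, 128, 1024])

        # Fused dropout(x0) + residual followed by RMSNorm. With prenorm=True the
        # pre-normalization sum is returned as well, so it can be carried as the
        # residual stream into the next block; residual_in_fp32=True keeps that
        # stream in fp32 when no residual is passed in (e.g. at the first layer).
        fused = DropoutAddRMSNorm(
            hidden_size,
            prenorm=True,
            p=0.1,
            residual_in_fp32=True,
            device="cuda",
            dtype=torch.float16,
        )
        fused.train()
        out1, residual = fused(x0)               # first block: residual created in fp32
        out2, residual = fused(out1, residual)   # later block: residual passed through
        print(out2.shape, residual.dtype)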