# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn


def rms_norm(x, weight, epsilon):
    # Reuse the fused dropout-add-layernorm autograd function with residual=None, bias=None,
    # dropout_p=0.0 and is_rms_norm=True, so only the RMSNorm part is active.
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)


def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                         layerscale=None, prenorm=False, residual_in_fp32=False,
                         return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual is kept in its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon,
        residual_in_fp32, prenorm, True, return_dropout_mask
    )
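

# Usage note (a sketch, not from the original file): in a pre-norm transformer block this
# function is typically called with prenorm=True, in which case it is expected to return both
# the normalized output and the updated residual stream, e.g.
#
#     hidden_states, residual = dropout_add_rms_norm(
#         hidden_states, residual, norm_weight, None, dropout_p, 1e-5,
#         prenorm=True, residual_in_fp32=True,
#     )
#
# With prenorm=False (the default) only the normalized output is returned. The names
# hidden_states, residual, norm_weight and dropout_p above are illustrative only.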


def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise the residual is kept in its own dtype (residual.dtype).
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        # RMSNorm has a learnable scale (weight) but no bias.
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        # Dropout is only applied in training mode.
        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                    self.p if self.training else 0.0, self.epsilon,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
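

# A minimal usage sketch (an illustrative example, not part of the original module). It assumes
# a CUDA device and the fused dropout_layer_norm extension behind flash_attn.ops.layer_norm are
# available; the shapes and dtypes below are arbitrary.
if __name__ == '__main__':
    if torch.cuda.is_available():
        batch, seqlen, hidden = 2, 128, 1024
        x = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16)
        norm = DropoutAddRMSNorm(hidden, prenorm=True, p=0.1, residual_in_fp32=True,
                                 device='cuda', dtype=torch.float16)
        out, residual = norm(x)  # with prenorm=True, the updated residual is also returned
        print(out.shape, out.dtype, residual.shape, residual.dtype)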