block.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # Copyright (c) 2022, Tri Dao.
  2. from typing import Optional
  3. from functools import partial
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch import Tensor
  8. from torchvision.ops import StochasticDepth
  9. from flash_attn.modules.mha import MHA
  10. from flash_attn.modules.mlp import Mlp
  11. try:
  12. from flash_attn.ops.layer_norm import dropout_add_layer_norm
  13. except ImportError:
  14. dropout_add_layer_norm = None
  15. try:
  16. from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
  17. except ImportError:
  18. dropout_add_layer_norm_parallel_residual = None
  19. class Block(nn.Module):
  20. def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
  21. dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
  22. drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
  23. residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
  24. """
  25. For prenorm=True, this Block has a slightly different structure compared to a regular
  26. prenorm Transformer block.
  27. The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
  28. [Ref: https://arxiv.org/abs/2002.04745]
  29. Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
  30. the hidden_states (output of the MLP) and the residual.
  31. This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
  32. The residual needs to be provided (except for the very first block).
  33. For prenorm=False, this Block has the same structure as a regular postnorm Transformer
  34. block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
  35. return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
  36. This is for performance reason: for post-norm architecture, returning the input allows us
  37. to fuse the backward of nn.Linear with the residual connection.
  38. """
  39. super().__init__()
  40. self.prenorm = prenorm
  41. self.fused_dropout_add_ln = fused_dropout_add_ln
  42. self.return_residual = return_residual
  43. self.residual_in_fp32 = residual_in_fp32
  44. if self.residual_in_fp32:
  45. assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
  46. if mixer_cls is None:
  47. mixer_cls = partial(MHA, num_heads=dim // 64)
  48. if mlp_cls is None:
  49. mlp_cls = partial(Mlp, hidden_features=4 * dim)
  50. self.mixer = mixer_cls(dim)
  51. self.dropout1 = dropout_cls(resid_dropout1)
  52. self.drop_path1 = StochasticDepth(drop_path1, mode='row')
  53. self.norm1 = norm_cls(dim)
  54. self.mlp = mlp_cls(dim)
  55. if not isinstance(self.mlp, nn.Identity):
  56. self.dropout2 = dropout_cls(resid_dropout2)
  57. self.drop_path2 = StochasticDepth(drop_path2, mode='row')
  58. self.norm2 = norm_cls(dim)
  59. if self.fused_dropout_add_ln:
  60. assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
  61. assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
  62. # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
  63. # then the input to each worker in the tensor parallel group will be different.
  64. # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
  65. # For now this is not an issue because we always use sequence_parallel=True during training
  66. # and only use sequence_parallel=False during inference.
  67. # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
  68. if sequence_parallel:
  69. for p in self.norm1.parameters():
  70. p._sequence_parallel = True
  71. if hasattr(self, 'norm2'):
  72. for p in self.norm2.parameters():
  73. p._sequence_parallel = True
  74. # Mark the norm parameters as "shared_params" so that we sync their values at init.
  75. if mark_shared_params:
  76. for p in self.norm1.parameters():
  77. p._shared_params = True
  78. if hasattr(self, 'norm2'):
  79. for p in self.norm2.parameters():
  80. p._shared_params = True
  81. def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
  82. mixer_subset=None, mixer_kwargs=None):
  83. r"""Pass the input through the encoder layer.
  84. Args:
  85. hidden_states: the sequence to the encoder layer (required).
  86. residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
  87. mixer_subset: for cross-attention only. If not None, will take a subset of x
  88. before applying the query projection. Useful for e.g., ViT where we only care
  89. about the CLS token in the last layer.
  90. """
  91. if self.prenorm:
  92. if not self.fused_dropout_add_ln:
  93. dropped = self.drop_path1(self.dropout1(hidden_states))
  94. residual = (dropped + residual) if residual is not None else dropped
  95. hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
  96. if self.residual_in_fp32:
  97. residual = residual.to(torch.float32)
  98. else:
  99. if self.drop_path1.p == 0 or not self.training:
  100. rowscale1 = None
  101. else:
  102. rowscale1 = self.drop_path1(torch.ones(
  103. hidden_states.shape[:-1], device=hidden_states.device,
  104. dtype=hidden_states.dtype)
  105. )
  106. hidden_states, residual = dropout_add_layer_norm(
  107. hidden_states, residual, self.norm1.weight, self.norm1.bias,
  108. self.dropout1.p if self.training else 0.0, self.norm1.eps,
  109. rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
  110. )
  111. if mixer_kwargs is None:
  112. mixer_kwargs = {}
  113. if mixer_subset is not None:
  114. mixer_kwargs['mixer_subset'] = mixer_subset
  115. hidden_states = self.mixer(hidden_states, **mixer_kwargs)
  116. if mixer_subset is not None:
  117. residual = residual[:, mixer_subset]
  118. if not isinstance(self.mlp, nn.Identity):
  119. if not self.fused_dropout_add_ln:
  120. dropped = self.drop_path2(self.dropout2(hidden_states))
  121. residual = (dropped + residual) if residual is not None else dropped
  122. hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
  123. if self.residual_in_fp32:
  124. residual = residual.to(torch.float32)
  125. else:
  126. if self.drop_path2.p == 0 or not self.training:
  127. rowscale2 = None
  128. else:
  129. rowscale2 = self.drop_path2(torch.ones(
  130. hidden_states.shape[:-1], device=hidden_states.device,
  131. dtype=hidden_states.dtype)
  132. )
  133. hidden_states, residual = dropout_add_layer_norm(
  134. hidden_states, residual, self.norm2.weight, self.norm2.bias,
  135. self.dropout2.p if self.training else 0.0, self.norm2.eps,
  136. rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
  137. )
  138. hidden_states = self.mlp(hidden_states)
  139. return hidden_states, residual
  140. else:
  141. assert residual is None
  142. mixer_out = self.mixer(
  143. hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
  144. )
  145. if self.return_residual: # mixer out is actually a pair here
  146. mixer_out, hidden_states = mixer_out
  147. if not self.fused_dropout_add_ln:
  148. hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
  149. + hidden_states).to(dtype=self.norm1.weight.dtype))
  150. else:
  151. if self.drop_path1.p == 0 or not self.training:
  152. rowscale1 = None
  153. else:
  154. rowscale1 = self.drop_path1(torch.ones(
  155. mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
  156. )
  157. hidden_states = dropout_add_layer_norm(
  158. mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
  159. self.dropout1.p if self.training else 0.0, self.norm1.eps,
  160. rowscale=rowscale1, prenorm=False
  161. )
  162. if not isinstance(self.mlp, nn.Identity):
  163. mlp_out = self.mlp(hidden_states)
  164. if self.return_residual: # mlp out is actually a pair here
  165. mlp_out, hidden_states = mlp_out
  166. if not self.fused_dropout_add_ln:
  167. hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
  168. + hidden_states).to(dtype=self.norm2.weight.dtype))
  169. else:
  170. if self.drop_path2.p == 0 or not self.training:
  171. rowscale2 = None
  172. else:
  173. rowscale2 = self.drop_path2(torch.ones(
  174. mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
  175. )
  176. hidden_states = dropout_add_layer_norm(
  177. mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
  178. self.dropout2.p if self.training else 0.0, self.norm2.eps,
  179. rowscale=rowscale2, prenorm=False
  180. )
  181. return hidden_states
  182. class ParallelBlock(nn.Module):
  183. """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
  184. and PaLM.
  185. """
  186. def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
  187. dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
  188. tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
  189. sequence_parallel=False, mark_shared_params=False):
  190. """
  191. This Block has a slightly different structure compared to a regular
  192. prenorm Transformer block.
  193. The standard block is: LN -> MHA / MLP -> Dropout -> Add.
  194. [Ref: https://arxiv.org/abs/2002.04745]
  195. Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
  196. the hidden_states (output1 of the MHA / MLP) and the residual.
  197. This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
  198. The residual needs to be provided (except for the very first block).
  199. """
  200. super().__init__()
  201. self.tied_norm = tied_norm
  202. self.fused_dropout_add_ln = fused_dropout_add_ln
  203. self.residual_in_fp32 = residual_in_fp32
  204. if mixer_cls is None:
  205. mixer_cls = partial(MHA, num_heads=dim // 64)
  206. if mlp_cls is None:
  207. mlp_cls = partial(Mlp, hidden_features=4 * dim)
  208. self.mixer = mixer_cls(dim)
  209. self.dropout1 = dropout_cls(resid_dropout1)
  210. self.norm1 = norm_cls(dim)
  211. self.mlp = mlp_cls(dim)
  212. self.dropout2 = dropout_cls(resid_dropout2)
  213. if not self.tied_norm:
  214. self.norm2 = norm_cls(dim)
  215. if self.fused_dropout_add_ln:
  216. assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
  217. assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
  218. # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
  219. # then the input to each worker in the tensor parallel group will be different.
  220. # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
  221. # For now this is not an issue because we always use sequence_parallel=True during training
  222. # and only use sequence_parallel=False during inference.
  223. # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
  224. if sequence_parallel:
  225. for p in self.norm1.parameters():
  226. p._sequence_parallel = True
  227. if hasattr(self, 'norm2'):
  228. for p in self.norm2.parameters():
  229. p._sequence_parallel = True
  230. # Mark the norm parameters as "shared_params" so that we sync their values at init.
  231. if mark_shared_params:
  232. for p in self.norm1.parameters():
  233. p._shared_params = True
  234. if hasattr(self, 'norm2'):
  235. for p in self.norm2.parameters():
  236. p._shared_params = True
  237. def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
  238. residual: Optional[Tensor] = None, mixer_kwargs=None):
  239. r"""Pass the input through the encoder layer.
  240. Args:
  241. hidden_states1: the output of the previous attention (mixer) or embedding layer.
  242. hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
  243. residual.
  244. """
  245. if not self.fused_dropout_add_ln:
  246. dropped1 = self.dropout1(hidden_states1)
  247. # For the very 1st block, we only want 1 dropout, not two different dropouts
  248. if hidden_states2 is not None:
  249. dropped2 = self.dropout2(hidden_states2)
  250. residual = ((residual + dropped1 + dropped2)
  251. if residual is not None else dropped1 + dropped2)
  252. else:
  253. residual = (residual + dropped1) if residual is not None else dropped1
  254. hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
  255. hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
  256. if not self.tied_norm else hidden_states1)
  257. if self.residual_in_fp32:
  258. residual = residual.to(torch.float32)
  259. else:
  260. weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
  261. if not self.tied_norm else (None, None))
  262. hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
  263. hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
  264. weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
  265. prenorm=True, residual_in_fp32=self.residual_in_fp32
  266. )
  267. if self.tied_norm:
  268. hidden_states2 = hidden_states1
  269. if mixer_kwargs is None:
  270. mixer_kwargs = {}
  271. hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
  272. hidden_states2 = self.mlp(hidden_states2)
  273. return hidden_states1, hidden_states2, residual