mlp.py

# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup


try:
    from flash_attn.ops.activations import swiglu
except ImportError:
    swiglu = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None
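
# The imports above are optional fused kernels/layers from flash-attn's CUDA
# extensions.  When an extension is not built, the corresponding name is set to
# None: GatedMlp then falls back to a plain PyTorch gating path, and the
# Parallel* classes refuse to construct (assert / ImportError below).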


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
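
# Usage sketch for Mlp (the shapes below are illustrative, not part of the
# module): fc1 expands the last dimension to hidden_features (default
# 4 * in_features), `activation` (GELU by default) is applied elementwise, and
# fc2 projects back to out_features.  With return_residual=True, forward
# returns (output, input) so the caller can fuse the residual addition.
#
#     mlp = Mlp(in_features=512)            # hidden_features defaults to 2048
#     y = mlp(torch.randn(2, 16, 512))      # y.shape == (2, 16, 512)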


class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y
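
# Tensor-parallel layout (a brief note; the exact communication pattern lives
# in flash_attn.ops.fused_dense): fc1 is column-parallel, so each rank in
# process_group holds a slice of the hidden dimension, and fc2 is row-parallel,
# so each rank produces a partial output that is combined across ranks
# (all-reduced, or reduce-scattered when sequence_parallel=True).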


class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
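
# Notes on GatedMlp (the worked numbers are illustrative): the default hidden
# size int(8 * in_features / 3), rounded up to a multiple of `multiple_of`,
# keeps the parameter count close to a standard 4x MLP even though fc1 emits
# both a value and a gate (2 * hidden_features).  For in_features=4096 and
# multiple_of=128: int(8 * 4096 / 3) = 10922, rounded up to 11008.  The three
# forward branches are equivalent formulations of GLU-style gating: F.glu for
# sigmoid, the fused swiglu kernel for SiLU when available, and the generic
# y * activation(gate) fallback otherwise.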


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
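

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: exercises the
    # non-parallel classes on CPU with made-up shapes (the Parallel* variants
    # need an initialized process group and the fused_dense extension).
    torch.manual_seed(0)
    x = torch.randn(2, 16, 512)

    mlp = Mlp(in_features=512, return_residual=True)
    y, residual = mlp(x)
    assert y.shape == x.shape and residual is x

    gated = GatedMlp(in_features=512, activation=F.silu)
    # hidden_features: int(8 * 512 / 3) = 1365, rounded up to 1408; fc1 emits
    # value and gate halves, so its output width is 2 * 1408 = 2816.
    assert gated.fc1.out_features == 2 * gated.fc2.in_features == 2816
    assert gated(x).shape == x.shape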