activation.py

  1. """Custom activation functions."""
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
  8. get_tensor_model_parallel_world_size)
  9. from aphrodite.modeling._custom_op import CustomOp
  10. from aphrodite.modeling.utils import set_weight_attrs
  11. from aphrodite.quantization import QuantizationConfig
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite.modeling.layers.ops.activation import swiglu_fg_kernel

        d = x.shape[-1] // 2
        return swiglu_fg_kernel(x[..., :d], x[..., d:])
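
# Usage sketch for SiluAndMul (illustrative only; `num_tokens` and `d` are
# placeholder sizes, and which forward_* path runs depends on how CustomOp
# dispatches on the current platform):
#
#     layer = SiluAndMul()
#     x = torch.randn(num_tokens, 2 * d)
#     out = layer(x)                       # silu(x[..., :d]) * x[..., d:]
#     assert out.shape == (num_tokens, d)
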
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite.modeling.layers.ops.activation import (
            geglu_approx_forward_kernel, geglu_exact_forward_kernel)

        d = x.shape[-1] // 2
        gate = x[..., :d]
        up = x[..., d:]
        if self.approximate == "none":
            return geglu_exact_forward_kernel(gate, up)
        elif self.approximate == "tanh":
            return geglu_approx_forward_kernel(gate, up)

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'
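
# Usage sketch for GeluAndMul (illustrative only): "none" selects the exact
# (erf-based) GELU kernels, while "tanh" selects the tanh-approximated ones.
#
#     layer = GeluAndMul(approximate="tanh")
#     x = torch.randn(num_tokens, 2 * d)   # placeholder sizes
#     out = layer(x)                       # gelu(x[..., :d]) * x[..., d:]
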
class NewGELU(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite._ipex_ops import ipex_ops as ops

        return ops.gelu_new(x)

    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite.modeling.layers.ops.activation import gelu_new_kernel

        return gelu_new_kernel(x)
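
# NewGELU is the GPT-2 style tanh approximation of GELU:
#     0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
# so forward_native should closely match F.gelu(x, approximate="tanh").
# A rough equivalence check (illustrative only, not part of the test suite):
#
#     x = torch.randn(8)
#     torch.testing.assert_close(NewGELU().forward_native(x),
#                                F.gelu(x, approximate="tanh"))
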
class FastGELU(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite._ipex_ops import ipex_ops as ops

        return ops.gelu_fast(x)

    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite.modeling.layers.ops.activation import fast_gelu_kernel

        return fast_gelu_kernel(x)


class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite.modeling.layers.ops.activation import quick_gelu_kernel

        return quick_gelu_kernel(x)


class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
        from aphrodite.modeling.layers.ops.activation import relu_squared_kernel

        return relu_squared_kernel(x)
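
# Worked example for ReLUSquaredActivation, i.e. relu(x)**2 (illustrative only):
#
#     x = torch.tensor([-2.0, 0.0, 3.0])
#     ReLUSquaredActivation().forward_native(x)   # tensor([0., 0., 9.])
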
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
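
# Sketch of how ScaledActivation is used (illustrative only; 11008 is a
# placeholder intermediate size): under tensor parallelism each rank holds a
# 1D shard of the per-channel scales, so weight_loader narrows the checkpoint
# tensor to this rank's slice before copying it in, and forward divides the
# wrapped activation's output element-wise by those scales.
#
#     scaled_act = ScaledActivation(nn.GELU(), intermediate_size=11008)
#     y = scaled_act(x)   # gelu(x) / scales, broadcast over the last dim
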
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "relu2": ReLUSquaredActivation(),
    "quick_gelu": QuickGELU(),
}


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn
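
# Usage sketch for get_act_fn (illustrative only; "gelu_new" is one of the
# registry keys above, `awq_config` is a hypothetical QuantizationConfig, and
# quant_config/intermediate_size only matter when the quantization method
# lists the activation in get_scaled_act_names()):
#
#     act = get_act_fn("gelu_new")
#     y = act(hidden_states)
#
#     scaled = get_act_fn("gelu", quant_config=awq_config,
#                         intermediate_size=11008)   # -> ScaledActivation
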