1
0

activation.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. """Custom activation functions."""
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
  8. get_tensor_model_parallel_world_size)
  9. from aphrodite.modeling._custom_op import CustomOp
  10. from aphrodite.modeling.utils import set_weight_attrs
  11. from aphrodite.quantization import QuantizationConfig
  class SiluAndMul(CustomOp):
      """SwiGLU activation: ``silu(x[..., :d]) * x[..., d:]``.

      ``d`` is half the size of the last dimension (``x.shape[-1] // 2``).

      Shapes:
          x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
          return: (num_tokens, d) or (batch_size, seq_len, d)
      """

      @staticmethod
      def _fused_silu_and_mul(ops, x: torch.Tensor) -> torch.Tensor:
          """Allocate the halved output and fill it via ``ops.silu_and_mul``."""
          half = x.shape[-1] // 2
          result = torch.empty((*x.shape[:-1], half),
                               dtype=x.dtype,
                               device=x.device)
          ops.silu_and_mul(result, x)
          return result

      def forward_native(self, x: torch.Tensor) -> torch.Tensor:
          """Reference implementation built from stock PyTorch ops."""
          half = x.shape[-1] // 2
          gate, up = x[..., :half], x[..., half:]
          return F.silu(gate) * up

      def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from ``aphrodite._custom_ops``."""
          from aphrodite import _custom_ops as ops
          return self._fused_silu_and_mul(ops, x)

      def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from the IPEX op bindings."""
          from aphrodite._ipex_ops import ipex_ops as ops
          return self._fused_silu_and_mul(ops, x)
  class GeluAndMul(CustomOp):
      """GeGLU activation: ``gelu(x[..., :d]) * x[..., d:]``.

      ``d`` is half the size of the last dimension (``x.shape[-1] // 2``).

      Args:
          approximate: GELU variant, either "none" (exact) or "tanh";
              any other value raises ``ValueError``.

      Shapes:
          x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
          return: (batch_size, seq_len, d) or (num_tokens, d)
      """

      def __init__(self, approximate: str = "none"):
          super().__init__()
          self.approximate = approximate
          if approximate not in ("none", "tanh"):
              raise ValueError(f"Unknown approximate mode: {approximate}")

      def _fused_gelu_and_mul(self, ops, x: torch.Tensor) -> torch.Tensor:
          """Allocate the halved output and fill it with the kernel matching
          ``self.approximate`` from ``ops``."""
          half = x.shape[-1] // 2
          result = torch.empty((*x.shape[:-1], half),
                               dtype=x.dtype,
                               device=x.device)
          mode = self.approximate
          # Only the kernel for the active mode is looked up on ``ops``.
          if mode == "tanh":
              ops.gelu_tanh_and_mul(result, x)
          elif mode == "none":
              ops.gelu_and_mul(result, x)
          return result

      def forward_native(self, x: torch.Tensor) -> torch.Tensor:
          """Reference implementation built from stock PyTorch ops."""
          half = x.shape[-1] // 2
          gate, up = x[..., :half], x[..., half:]
          return F.gelu(gate, approximate=self.approximate) * up

      def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from ``aphrodite._custom_ops``."""
          from aphrodite import _custom_ops as ops
          return self._fused_gelu_and_mul(ops, x)

      def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from the IPEX op bindings."""
          from aphrodite._ipex_ops import ipex_ops as ops
          return self._fused_gelu_and_mul(ops, x)

      def extra_repr(self) -> str:
          return f'approximate={repr(self.approximate)}'
  class NewGELU(CustomOp):
      """Tanh-approximated GELU, registered as "gelu_new".

      Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
      """

      def forward_native(self, x: torch.Tensor) -> torch.Tensor:
          """Reference implementation built from stock PyTorch ops."""
          sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
          cubic = 0.044715 * torch.pow(x, 3.0)
          inner = sqrt_2_over_pi * (x + cubic)
          return 0.5 * x * (1.0 + torch.tanh(inner))

      def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from ``aphrodite._custom_ops``."""
          from aphrodite import _custom_ops as ops
          result = torch.empty_like(x)
          ops.gelu_new(result, x)
          return result

      def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from the IPEX op bindings."""
          from aphrodite._ipex_ops import ipex_ops as ops
          result = torch.empty_like(x)
          ops.gelu_new(result, x)
          return result
  class FastGELU(CustomOp):
      """Fast tanh-approximated GELU, registered as "gelu_fast".

      Computes ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x**2)))``,
      where ``0.7978845608`` ~= ``sqrt(2 / pi)``.
      """

      def forward_native(self, x: torch.Tensor) -> torch.Tensor:
          """PyTorch-native implementation equivalent to forward()."""
          return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                             (1.0 + 0.044715 * x * x)))

      # Must be ``forward_cuda``, not ``forward``. Every other op in this file
      # defines only ``forward_<backend>`` methods, so ``forward`` is
      # presumably CustomOp's per-backend dispatcher. Overriding it would
      # bypass that dispatch, route every device to the ``_custom_ops``
      # kernel, and leave ``forward_native`` / ``forward_xpu`` unused.
      def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from ``aphrodite._custom_ops``."""
          from aphrodite import _custom_ops as ops
          out = torch.empty_like(x)
          ops.gelu_fast(out, x)
          return out

      def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from the IPEX op bindings."""
          from aphrodite._ipex_ops import ipex_ops as ops
          out = torch.empty_like(x)
          ops.gelu_fast(out, x)
          return out
  class QuickGELU(CustomOp):
      """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``.

      Registered as "quick_gelu".
      """
      # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

      def forward_native(self, x: torch.Tensor) -> torch.Tensor:
          """Reference implementation built from stock PyTorch ops."""
          gate = torch.sigmoid(1.702 * x)
          return x * gate

      def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
          """Fused kernel from ``aphrodite._custom_ops``."""
          from aphrodite import _custom_ops as ops
          result = torch.empty_like(x)
          ops.gelu_quick(result, x)
          return result
  class ReLUSquaredActivation(CustomOp):
      """Squared ReLU, ``relu(x) ** 2``, registered as "relu2".

      Introduced in https://arxiv.org/abs/2109.08668v2
      """

      def forward_native(self, x: torch.Tensor) -> torch.Tensor:
          """Reference implementation built from stock PyTorch ops."""
          rectified = F.relu(x)
          return torch.square(rectified)

      def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
          """No dedicated kernel: reuse the PyTorch implementation."""
          return self.forward_native(x)
  class ScaledActivation(nn.Module):
      """An activation function with post-scale parameters.

      The wrapped activation's output is divided by a learnable 1-D
      ``scales`` tensor, which broadcasts over the last dimension. This is
      used for some quantization methods like AWQ. ``scales`` is populated
      through ``weight_loader``.

      Args:
          act_module: Activation applied before scaling.
          intermediate_size: Full (unsharded) length of the scaled dimension.
          input_is_parallel: If True, ``intermediate_size`` is split across
              tensor-parallel ranks via ``divide`` and each rank holds only
              its own shard of ``scales``.
          params_dtype: dtype of ``scales``; defaults to
              ``torch.get_default_dtype()``.
      """

      def __init__(
          self,
          act_module: nn.Module,
          intermediate_size: int,
          input_is_parallel: bool = True,
          params_dtype: Optional[torch.dtype] = None,
      ):
          super().__init__()
          self.act = act_module
          self.input_is_parallel = input_is_parallel
          if input_is_parallel:
              tp_size = get_tensor_model_parallel_world_size()
              intermediate_size_per_partition = divide(intermediate_size,
                                                       tp_size)
          else:
              intermediate_size_per_partition = intermediate_size
          if params_dtype is None:
              params_dtype = torch.get_default_dtype()
          self.scales = nn.Parameter(
              torch.empty(intermediate_size_per_partition, dtype=params_dtype))
          set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply the activation, then divide by ``scales``."""
          return self.act(x) / self.scales

      def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
          """Copy this rank's slice of ``loaded_weight`` into ``param``.

          When ``input_is_parallel`` is set, the slice for this rank is
          ``[rank * shard_size, (rank + 1) * shard_size)`` along dim 0, where
          ``shard_size`` is ``param``'s length.

          Raises:
              ValueError: if the (sliced) tensor's shape differs from
                  ``param``'s shape.
          """
          param_data = param.data
          if self.input_is_parallel:
              tp_rank = get_tensor_model_parallel_rank()
              shard_size = param_data.shape[0]
              start_idx = tp_rank * shard_size
              loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
          # Explicit check instead of ``assert``. Asserts are stripped under
          # ``python -O``, and ``copy_`` would then silently broadcast a
          # mis-shaped tensor (e.g. shape (1,)) into the parameter.
          if param_data.shape != loaded_weight.shape:
              raise ValueError(
                  "Shape mismatch while loading activation scales: parameter "
                  f"has shape {tuple(param_data.shape)}, loaded tensor has "
                  f"shape {tuple(loaded_weight.shape)}.")
          param_data.copy_(loaded_weight)
  # Name -> activation module lookup used by get_act_fn(). Keys must be
  # lower-case because get_act_fn() lower-cases the requested name before
  # looking it up. Each value is one shared instance, which get_act_fn()
  # returns directly unless it wraps it in ScaledActivation.
  _ACTIVATION_REGISTRY = dict([
      ("gelu", nn.GELU()),
      ("gelu_fast", FastGELU()),
      ("gelu_new", NewGELU()),
      ("gelu_pytorch_tanh", nn.GELU(approximate="tanh")),
      ("relu", nn.ReLU()),
      ("relu2", ReLUSquaredActivation()),
      ("quick_gelu", QuickGELU()),
  ])
  def get_act_fn(
      act_fn_name: str,
      quant_config: Optional[QuantizationConfig] = None,
      intermediate_size: Optional[int] = None,
      input_is_parallel: bool = True,
      params_dtype: Optional[torch.dtype] = None,
  ) -> nn.Module:
      """Look up an activation module by case-insensitive name.

      Args:
          act_fn_name: Registry key such as "gelu" or "relu2"; it is
              lower-cased before lookup.
          quant_config: If given and this activation appears in its
              ``get_scaled_act_names()``, the activation is wrapped in
              ``ScaledActivation``.
          intermediate_size: Required when the activation gets wrapped.
          input_is_parallel: Forwarded to ``ScaledActivation``.
          params_dtype: Forwarded to ``ScaledActivation``.

      Returns:
          The shared registry module, or a new ``ScaledActivation`` wrapping it.

      Raises:
          ValueError: if the name is not registered, or if scaling is
              required but ``intermediate_size`` is None.
      """
      act_fn_name = act_fn_name.lower()
      act_fn = _ACTIVATION_REGISTRY.get(act_fn_name)
      if act_fn is None:
          raise ValueError(
              f"Activation function {act_fn_name!r} is not supported.")

      needs_scaling = (quant_config is not None
                       and act_fn_name in quant_config.get_scaled_act_names())
      if not needs_scaling:
          return act_fn

      if intermediate_size is None:
          raise ValueError("intermediate_size must be specified for scaled "
                           "activation functions.")
      return ScaledActivation(act_module=act_fn,
                              intermediate_size=intermediate_size,
                              input_is_parallel=input_is_parallel,
                              params_dtype=params_dtype)