activation.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """Custom activation functions."""
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
  8. get_tensor_model_parallel_world_size)
  9. from aphrodite.modeling.utils import set_weight_attrs
  10. from aphrodite.quantization import QuantizationConfig
  11. from aphrodite.modeling._custom_op import CustomOp
  12. class SiluAndMul(CustomOp):
  13. """An activation function for SwiGLU.
  14. The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
  15. Shapes:
  16. x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
  17. return: (num_tokens, d) or (batch_size, seq_len, d)
  18. """
  19. def forward_native(self, x: torch.Tensor) -> torch.Tensor:
  20. """PyTorch-native implementation equivalent to forward()."""
  21. d = x.shape[-1] // 2
  22. return F.silu(x[..., :d]) * x[..., d:]
  23. def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
  24. from aphrodite import _custom_ops as ops
  25. d = x.shape[-1] // 2
  26. output_shape = (x.shape[:-1] + (d, ))
  27. out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
  28. ops.silu_and_mul(out, x)
  29. return out
  30. class GeluAndMul(CustomOp):
  31. """An activation function for GeGLU.
  32. The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
  33. Shapes:
  34. x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
  35. return: (batch_size, seq_len, d) or (num_tokens, d)
  36. """
  37. def __init__(self, approximate: str = "none"):
  38. super().__init__()
  39. self.approximate = approximate
  40. if approximate not in ("none", "tanh"):
  41. raise ValueError(f"Unknown approximate mode: {approximate}")
  42. def forward_native(self, x: torch.Tensor) -> torch.Tensor:
  43. """PyTorch-native implementation equivalent to forward()."""
  44. d = x.shape[-1] // 2
  45. return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
  46. def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
  47. from aphrodite import _custom_ops as ops
  48. d = x.shape[-1] // 2
  49. output_shape = (x.shape[:-1] + (d, ))
  50. out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
  51. if self.approximate == "none":
  52. ops.gelu_and_mul(out, x)
  53. elif self.approximate == "tanh":
  54. ops.gelu_tanh_and_mul(out, x)
  55. return out
  56. def extra_repr(self) -> str:
  57. return f'approximate={repr(self.approximate)}'
  58. class NewGELU(CustomOp):
  59. def forward_native(self, x: torch.Tensor) -> torch.Tensor:
  60. """PyTorch-native implementation equivalent to forward()."""
  61. c = math.sqrt(2.0 / math.pi)
  62. return 0.5 * x * (1.0 + torch.tanh(c *
  63. (x + 0.044715 * torch.pow(x, 3.0))))
  64. def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
  65. from aphrodite import _custom_ops as ops
  66. out = torch.empty_like(x)
  67. ops.gelu_new(out, x)
  68. return out
  69. class FastGELU(CustomOp):
  70. def forward_native(self, x: torch.Tensor) -> torch.Tensor:
  71. """PyTorch-native implementation equivalent to forward()."""
  72. return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
  73. (1.0 + 0.044715 * x * x)))
  74. def forward(self, x: torch.Tensor) -> torch.Tensor:
  75. from aphrodite import _custom_ops as ops
  76. out = torch.empty_like(x)
  77. ops.gelu_fast(out, x)
  78. return out
  79. class ScaledActivation(nn.Module):
  80. """An activation function with post-scale parameters.
  81. This is used for some quantization methods like AWQ.
  82. """
  83. def __init__(
  84. self,
  85. act_module: nn.Module,
  86. intermediate_size: int,
  87. input_is_parallel: bool = True,
  88. params_dtype: Optional[torch.dtype] = None,
  89. ):
  90. super().__init__()
  91. self.act = act_module
  92. self.input_is_parallel = input_is_parallel
  93. if input_is_parallel:
  94. tp_size = get_tensor_model_parallel_world_size()
  95. intermediate_size_per_partition = divide(intermediate_size,
  96. tp_size)
  97. else:
  98. intermediate_size_per_partition = intermediate_size
  99. if params_dtype is None:
  100. params_dtype = torch.get_default_dtype()
  101. self.scales = nn.Parameter(
  102. torch.empty(intermediate_size_per_partition, dtype=params_dtype))
  103. set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
  104. def forward(self, x: torch.Tensor) -> torch.Tensor:
  105. return self.act(x) / self.scales
  106. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
  107. param_data = param.data
  108. if self.input_is_parallel:
  109. tp_rank = get_tensor_model_parallel_rank()
  110. shard_size = param_data.shape[0]
  111. start_idx = tp_rank * shard_size
  112. loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
  113. assert param_data.shape == loaded_weight.shape
  114. param_data.copy_(loaded_weight)
  115. _ACTIVATION_REGISTRY = {
  116. "gelu": nn.GELU(),
  117. "gelu_fast": FastGELU(),
  118. "gelu_new": NewGELU(),
  119. "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
  120. "relu": nn.ReLU(),
  121. }
  122. def get_act_fn(
  123. act_fn_name: str,
  124. quant_config: Optional[QuantizationConfig] = None,
  125. intermediate_size: Optional[int] = None,
  126. input_is_parallel: bool = True,
  127. params_dtype: Optional[torch.dtype] = None,
  128. ) -> nn.Module:
  129. """Get an activation function by name."""
  130. act_fn_name = act_fn_name.lower()
  131. if act_fn_name not in _ACTIVATION_REGISTRY:
  132. raise ValueError(
  133. f"Activation function {act_fn_name!r} is not supported.")
  134. act_fn = _ACTIVATION_REGISTRY[act_fn_name]
  135. if (quant_config is not None
  136. and act_fn_name in quant_config.get_scaled_act_names()):
  137. if intermediate_size is None:
  138. raise ValueError("intermediate_size must be specified for scaled "
  139. "activation functions.")
  140. return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
  141. params_dtype)
  142. return act_fn