activation.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """Custom activation functions."""
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from aphrodite._C import ops
  8. from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
  9. get_tensor_model_parallel_world_size)
  10. from aphrodite.modeling.utils import set_weight_attrs
  11. from aphrodite.quantization import QuantizationConfig
  12. class SiluAndMul(nn.Module):
  13. """An activation function for SwiGLU.
  14. The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
  15. Shapes:
  16. x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
  17. return: (num_tokens, d) or (batch_size, seq_len, d)
  18. """
  19. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  20. """PyTorch-native implementation equivalent to forward()."""
  21. d = x.shape[-1] // 2
  22. return F.silu(x[..., :d]) * x[..., d:]
  23. def forward(self, x: torch.Tensor) -> torch.Tensor:
  24. d = x.shape[-1] // 2
  25. output_shape = (x.shape[:-1] + (d, ))
  26. out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
  27. ops.silu_and_mul(out, x)
  28. return out
  29. class GeluAndMul(nn.Module):
  30. """An activation function for GeGLU.
  31. The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
  32. Shapes:
  33. x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
  34. return: (batch_size, seq_len, d) or (num_tokens, d)
  35. """
  36. def __init__(self, approximate: str = "none"):
  37. super().__init__()
  38. self.approximate = approximate
  39. if approximate not in ("none", "tanh"):
  40. raise ValueError(f"Unknown approximate mode: {approximate}")
  41. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  42. """PyTorch-native implementation equivalent to forward()."""
  43. d = x.shape[-1] // 2
  44. return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
  45. def forward(self, x: torch.Tensor) -> torch.Tensor:
  46. d = x.shape[-1] // 2
  47. output_shape = (x.shape[:-1] + (d, ))
  48. out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
  49. if self.approximate == "none":
  50. ops.gelu_and_mul(out, x)
  51. elif self.approximate == "tanh":
  52. ops.gelu_tanh_and_mul(out, x)
  53. return out
  54. def extra_repr(self) -> str:
  55. return f'approximate={repr(self.approximate)}'
  56. class NewGELU(nn.Module):
  57. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  58. """PyTorch-native implementation equivalent to forward()."""
  59. c = math.sqrt(2.0 / math.pi)
  60. return 0.5 * x * (1.0 + torch.tanh(c *
  61. (x + 0.044715 * torch.pow(x, 3.0))))
  62. def forward(self, x: torch.Tensor) -> torch.Tensor:
  63. out = torch.empty_like(x)
  64. ops.gelu_new(out, x)
  65. return out
  66. class FastGELU(nn.Module):
  67. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  68. """PyTorch-native implementation equivalent to forward()."""
  69. return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
  70. (1.0 + 0.044715 * x * x)))
  71. def forward(self, x: torch.Tensor) -> torch.Tensor:
  72. out = torch.empty_like(x)
  73. ops.gelu_fast(out, x)
  74. return out
  75. class ScaledActivation(nn.Module):
  76. """An activation function with post-scale parameters.
  77. This is used for some quantization methods like AWQ.
  78. """
  79. def __init__(
  80. self,
  81. act_module: nn.Module,
  82. intermediate_size: int,
  83. input_is_parallel: bool = True,
  84. params_dtype: Optional[torch.dtype] = None,
  85. ):
  86. super().__init__()
  87. self.act = act_module
  88. self.input_is_parallel = input_is_parallel
  89. if input_is_parallel:
  90. tp_size = get_tensor_model_parallel_world_size()
  91. intermediate_size_per_partition = divide(intermediate_size,
  92. tp_size)
  93. else:
  94. intermediate_size_per_partition = intermediate_size
  95. if params_dtype is None:
  96. params_dtype = torch.get_default_dtype()
  97. self.scales = nn.Parameter(
  98. torch.empty(intermediate_size_per_partition, dtype=params_dtype))
  99. set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
  100. def forward(self, x: torch.Tensor) -> torch.Tensor:
  101. return self.act(x) / self.scales
  102. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
  103. param_data = param.data
  104. if self.input_is_parallel:
  105. tp_rank = get_tensor_model_parallel_rank()
  106. shard_size = param_data.shape[0]
  107. start_idx = tp_rank * shard_size
  108. loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
  109. assert param_data.shape == loaded_weight.shape
  110. param_data.copy_(loaded_weight)
  111. _ACTIVATION_REGISTRY = {
  112. "gelu": nn.GELU(),
  113. "gelu_fast": FastGELU(),
  114. "gelu_new": NewGELU(),
  115. "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
  116. "relu": nn.ReLU(),
  117. }
  118. def get_act_fn(
  119. act_fn_name: str,
  120. quant_config: Optional[QuantizationConfig] = None,
  121. intermediate_size: Optional[int] = None,
  122. input_is_parallel: bool = True,
  123. params_dtype: Optional[torch.dtype] = None,
  124. ) -> nn.Module:
  125. """Get an activation function by name."""
  126. act_fn_name = act_fn_name.lower()
  127. if act_fn_name not in _ACTIVATION_REGISTRY:
  128. raise ValueError(
  129. f"Activation function {act_fn_name!r} is not supported.")
  130. act_fn = _ACTIVATION_REGISTRY[act_fn_name]
  131. if (quant_config is not None
  132. and act_fn_name in quant_config.get_scaled_act_names()):
  133. if intermediate_size is None:
  134. raise ValueError("intermediate_size must be specified for scaled "
  135. "activation functions.")
  136. return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
  137. params_dtype)
  138. return act_fn