activation.py

  1. """Custom activation functions."""
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from aphrodite._C import ops
  8. from aphrodite.distributed import (
  9. divide,
  10. get_tensor_model_parallel_rank,
  11. get_tensor_model_parallel_world_size,
  12. )
  13. from aphrodite.quantization import QuantizationConfig
  14. from aphrodite.modeling.utils import set_weight_attrs


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out
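
# Usage sketch (illustrative shapes only): the module halves the last
# dimension, applies SiLU to the first half, and uses it to gate the second
# half, so the projection feeding it must produce 2 * d features. _forward()
# is the pure-PyTorch path and does not need the compiled ops extension.
#
#   act = SiluAndMul()
#   gate_up = torch.randn(4, 2 * 4096)    # (num_tokens, 2 * d)
#   hidden = act._forward(gate_up)        # (num_tokens, d) -> (4, 4096)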


class GeluAndMul(nn.Module):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Invalid approximate value: {approximate}")

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out
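
# Usage sketch (illustrative shapes only): GeGLU mirrors SwiGLU above but
# gates with GELU; passing approximate="tanh" makes forward() dispatch to the
# tanh-approximated kernel, while _forward() stays pure PyTorch.
#
#   act = GeluAndMul(approximate="tanh")
#   gate_up = torch.randn(4, 2 * 4096)    # (num_tokens, 2 * d)
#   hidden = act._forward(gate_up)        # (num_tokens, d)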


class NewGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out
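
# Sanity-check sketch: the formula above is the tanh approximation of GELU,
# so the native path should closely match PyTorch's built-in approximation,
# up to floating-point rounding.
#
#   x = torch.randn(8)
#   torch.allclose(NewGELU()._forward(x), F.gelu(x, approximate="tanh"),
#                  atol=1e-6)             # expected to be True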


class FastGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out
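
# Note on the constant: 0.7978845608 is sqrt(2 / pi), so FastGELU evaluates
# the same tanh-based GELU approximation as NewGELU with the cubic term
# factored as x * (1 + 0.044715 * x * x); the two native paths should agree
# to within floating-point rounding.
#
#   x = torch.randn(8)
#   torch.allclose(FastGELU()._forward(x), NewGELU()._forward(x), atol=1e-6)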


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
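
# Usage sketch (assumes no tensor parallelism and hand-filled scales): with
# input_is_parallel=False the module never touches the distributed state, and
# forward() just divides the wrapped activation's output by the per-channel
# scales that weight_loader() would normally copy from a quantized checkpoint.
#
#   act = ScaledActivation(nn.GELU(), intermediate_size=128,
#                          input_is_parallel=False,
#                          params_dtype=torch.float32)
#   act.weight_loader(act.scales, torch.full((128,), 2.0))
#   y = act(torch.randn(4, 128))          # equals nn.GELU()(x) / 2.0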


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn
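
# Usage sketch: without a quant_config the lookup simply returns the module
# registered in _ACTIVATION_REGISTRY; the ScaledActivation wrapper is only
# applied when the quantization method lists the name in
# get_scaled_act_names() and an intermediate_size is supplied.
#
#   act = get_act_fn("gelu_new")          # -> NewGELU()
#   y = act._forward(torch.randn(2, 16))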