activation.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. """Custom activation functions."""
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from aphrodite._C import ops
  8. from aphrodite.modeling.layers.quantization import QuantizationConfig
  9. from aphrodite.modeling.megatron.parallel_state import (
  10. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  11. from aphrodite.modeling.megatron.utils import divide
  12. from aphrodite.modeling.utils import set_weight_attrs
  13. class SiluAndMul(nn.Module):
  14. """An activation function for SwiGLU.
  15. The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
  16. Shapes:
  17. x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
  18. return: (batch_size, seq_len, d) or (num_tokens, d)
  19. """
  20. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  21. """PyTorch-native implementation equivalent to forward()."""
  22. d = x.shape[-1] // 2
  23. return F.silu(x[..., :d]) * x[..., d:]
  24. def forward(self, x: torch.Tensor) -> torch.Tensor:
  25. d = x.shape[-1] // 2
  26. output_shape = (x.shape[:-1] + (d, ))
  27. out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
  28. ops.silu_and_mul(out, x)
  29. return out
  30. class NewGELU(nn.Module):
  31. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  32. """PyTorch-native implementation equivalent to forward()."""
  33. c = math.sqrt(2.0 / math.pi)
  34. return 0.5 * x * (1.0 + torch.tanh(c *
  35. (x + 0.044715 * torch.pow(x, 3.0))))
  36. def forward(self, x: torch.Tensor) -> torch.Tensor:
  37. out = torch.empty_like(x)
  38. ops.gelu_new(out, x)
  39. return out
  40. class FastGELU(nn.Module):
  41. def _forward(self, x: torch.Tensor) -> torch.Tensor:
  42. """PyTorch-native implementation equivalent to forward()."""
  43. return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
  44. (1.0 + 0.044715 * x * x)))
  45. def forward(self, x: torch.Tensor) -> torch.Tensor:
  46. out = torch.empty_like(x)
  47. ops.gelu_fast(out, x)
  48. return out
  49. class ScaledActivation(nn.Module):
  50. """An activation function with post-scale parameters.
  51. This is used for some quantization methods like AWQ.
  52. """
  53. def __init__(
  54. self,
  55. act_module: nn.Module,
  56. intermediate_size: int,
  57. input_is_parallel: bool = True,
  58. params_dtype: Optional[torch.dtype] = None,
  59. ):
  60. super().__init__()
  61. self.act = act_module
  62. self.input_is_parallel = input_is_parallel
  63. if input_is_parallel:
  64. tp_size = get_tensor_model_parallel_world_size()
  65. intermediate_size_per_partition = divide(intermediate_size,
  66. tp_size)
  67. else:
  68. intermediate_size_per_partition = intermediate_size
  69. if params_dtype is None:
  70. params_dtype = torch.get_default_dtype()
  71. self.scales = nn.Parameter(
  72. torch.empty(intermediate_size_per_partition, dtype=params_dtype))
  73. set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
  74. def forward(self, x: torch.Tensor) -> torch.Tensor:
  75. return self.act(x) / self.scales
  76. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
  77. param_data = param.data
  78. if self.input_is_parallel:
  79. tp_rank = get_tensor_model_parallel_rank()
  80. shard_size = param_data.shape[0]
  81. start_idx = tp_rank * shard_size
  82. loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
  83. assert param_data.shape == loaded_weight.shape
  84. param_data.copy_(loaded_weight)
  85. _ACTIVATION_REGISTRY = {
  86. "gelu": nn.GELU(),
  87. "gelu_fast": FastGELU(),
  88. "gelu_new": NewGELU(),
  89. "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
  90. "relu": nn.ReLU(),
  91. }
  92. def get_act_fn(
  93. act_fn_name: str,
  94. quant_config: Optional[QuantizationConfig] = None,
  95. intermediate_size: Optional[int] = None,
  96. input_is_parallel: bool = True,
  97. params_dtype: Optional[torch.dtype] = None,
  98. ) -> nn.Module:
  99. """Get an activation function by name."""
  100. act_fn_name = act_fn_name.lower()
  101. if act_fn_name not in _ACTIVATION_REGISTRY:
  102. raise ValueError(
  103. f"Activation function {act_fn_name!r} is not supported.")
  104. act_fn = _ACTIVATION_REGISTRY[act_fn_name]
  105. if (quant_config is not None
  106. and act_fn_name in quant_config.get_scaled_act_names()):
  107. if intermediate_size is None:
  108. raise ValueError("intermediate_size must be specified for scaled "
  109. "activation functions.")
  110. return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
  111. params_dtype)
  112. return act_fn