# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )
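
# Note: the functions returned by the two getters above are themselves
# @triton.jit functions. They are meant to be passed into a Triton kernel
# (e.g. as a constexpr / meta-parameter) and called on a block of values
# inside that kernel, not invoked directly from Python. A minimal usage
# sketch is given at the end of this file.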


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# A Triton implementation of the most-used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations in that its gradient does not
    # depend on the exact input value, only on its sign: we return the local
    # derivative directly, 1 where x >= 0 and 0 elsewhere.
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
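    # With GELU(x) = x * Phi(x), where Phi is the standard normal CDF and phi
    # its PDF, the product rule gives d/dx GELU(x) = Phi(x) + x * phi(x),
    # i.e. the `cdf + x * pdf` expression returned below.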
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
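

# ---------------------------------------------------------------------------
# Usage sketch: a minimal elementwise kernel showing how the @triton.jit
# activations above are meant to be called from inside another Triton kernel
# rather than from Python directly. The kernel and launcher names
# (`_activation_fwd_kernel`, `apply_gelu`) are illustrative placeholders, not
# part of the upstream xformers API; real callers such as fused MLP kernels
# inline these activations into larger fused kernels instead. Running this
# sketch requires PyTorch and a CUDA GPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    @triton.jit
    def _activation_fwd_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # Any of the activations defined above could be swapped in here.
        y = gelu(x)
        tl.store(y_ptr + offsets, y, mask=mask)

    def apply_gelu(x):
        y = torch.empty_like(x)
        n_elements = x.numel()
        grid = (triton.cdiv(n_elements, 1024),)
        _activation_fwd_kernel[grid](x, y, n_elements, BLOCK_SIZE=1024)
        return y

    x = torch.randn(8192, device="cuda", dtype=torch.float32)
    # The exact (erf-based) Triton GELU should match PyTorch's default GELU.
    torch.testing.assert_close(
        apply_gelu(x), torch.nn.functional.gelu(x), rtol=1e-4, atol=1e-4
    )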