# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

# Constants shared by the GELU variants below: sqrt(2 / pi), 1 / sqrt(2),
# and the normalization factor of the standard normal PDF, 1 / sqrt(2 * pi).
_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)

class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"

def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )

def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )

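# Example (assumed usage, not from the original file): the two helpers above map
# an Activation value to the matching forward / backward Triton functions, which
# a fused kernel can then receive as meta-parameters, e.g.
#
#   act_fwd = get_triton_activation_kernel(Activation.SquaredReLU)      # -> squared_relu
#   act_bwd = get_triton_activation_bwd_kernel(Activation.SquaredReLU)  # -> squared_relu_grad
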
@triton.jit
def tanh(x):
    # Tanh expressed as a scaled sigmoid: tanh(x) = 2 * sigmoid(2x) - 1
    return 2 * tl.sigmoid(2 * x) - 1

@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5

# Triton implementations of the most commonly used activations.
# See for instance http://arxiv.org/abs/1606.08415 for an overview.

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))

@triton.jit
def relu_grad(x):
    # ReLU is different from other activations in that its gradient is a simple
    # mask on the sign of its argument: return 1 where x >= 0, and 0 elsewhere.
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))

@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)

@triton.jit
def squared_relu_grad(x):
    # d/dx relu(x)^2 = 2 * x for x >= 0, and 0 otherwise
    return tl.where(x >= 0, 2.0 * x, 0.0)

# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)

@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)

@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))

@triton.jit
def gelu_grad(x):
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf

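# Derivation note (added here for clarity; not in the adapted xformers file): with
# Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) the standard normal CDF and
# phi(x) = exp(-x^2 / 2) / sqrt(2 * pi) its PDF, GELU(x) = x * Phi(x), so the
# product rule gives d/dx GELU(x) = Phi(x) + x * phi(x), i.e. the `cdf + x * pdf`
# returned by `gelu_grad` above.
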
@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))

@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    # 0.79788456 ~= sqrt(2 / pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2 / pi),
    # i.e. the derivative of the tanh argument used in `gelu_approx`.
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )

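# ---------------------------------------------------------------------------
# Usage sketch (an illustration added here, not part of the adapted xformers
# file): these activations are meant to be inlined into larger Triton kernels,
# typically by passing the jitted function in as a meta-parameter. The kernel
# and wrapper below, and their names, are assumptions made for this example.
# ---------------------------------------------------------------------------
@triton.jit
def _activation_fwd_kernel(x_ptr, out_ptr, n_elements, ACTIVATION: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # ACTIVATION is one of the jitted activation functions defined above.
    tl.store(out_ptr + offsets, ACTIVATION(x), mask=mask)


def apply_activation(x, activation: Activation = Activation.GeLU, block_size: int = 1024):
    # Hypothetical convenience wrapper: applies `activation` elementwise to a
    # contiguous CUDA tensor and returns the result as a new tensor.
    import torch  # local import to keep the module's own imports unchanged

    assert x.is_cuda and x.is_contiguous()
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    _activation_fwd_kernel[grid](
        x,
        out,
        n_elements,
        ACTIVATION=get_triton_activation_kernel(activation),
        BLOCK_SIZE=block_size,
    )
    return out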