activations.py

# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi) -> 0.3989423
# 1/sqrt(2)    -> 0.70710678
# sqrt(2/pi)   -> 0.79788456


# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
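
# --- Illustrative check (not part of the original file) ---
# The comment above gives the exact erf-based GELU; this sketch compares the
# tanh approximation used by bias_gelu against that exact formula. The helper
# name `_check_tanh_approximation` is ours, for demonstration only.
def _check_tanh_approximation():
    x = torch.linspace(-4.0, 4.0, steps=101)
    bias = torch.zeros_like(x)
    approx = bias_gelu(x, bias)
    exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
    # The tanh approximation tracks the exact GELU to within roughly 1e-3 here.
    print((approx - exact).abs().max().item())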

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # bias_gelu_back returns (grad wrt input, grad wrt bias)
        tmp = bias_gelu_back(grad_output, input, bias)
        return tmp[0], tmp[1]


bias_gelu_impl = GeLUFunction.apply
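
# --- Illustrative usage sketch (not part of the original file) ---
# Checks the fused bias + GELU path against a reference built from
# F.gelu(x + bias, approximate="tanh") (available in PyTorch >= 1.12).
# The helper name `_check_bias_gelu_impl` is ours; tolerances are loose to
# absorb float32 rounding differences between the two formulations.
def _check_bias_gelu_impl():
    x = torch.randn(8, 16, requires_grad=True)
    bias = torch.randn(16, requires_grad=True)
    out = bias_gelu_impl(x, bias)
    out.sum().backward()

    x_ref = x.detach().clone().requires_grad_()
    bias_ref = bias.detach().clone().requires_grad_()
    ref = F.gelu(x_ref + bias_ref, approximate="tanh")
    ref.sum().backward()

    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(bias.grad, bias_ref.grad, rtol=1e-4, atol=1e-4)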

# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)

class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
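
# --- Illustrative gradient check (not part of the original file) ---
# gradcheck compares the hand-written gelu_bwd against numerical gradients of
# gelu_fwd; it needs double-precision inputs. The helper name
# `_gradcheck_fast_gelu` is ours, for demonstration only.
def _gradcheck_fast_gelu():
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(fast_gelu_impl, (x,))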

@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
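
# --- Illustrative sketch (not part of the original file) ---
# The file defines sqrelu_fwd / sqrelu_bwd but no autograd.Function wrapper;
# one could be wired up following the FastGeLUFunction pattern above. The
# class name `SqReLUFunction` and the alias `sqrelu_impl` are ours, for
# demonstration only.
class SqReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return sqrelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return sqrelu_bwd(grad_output, input)


sqrelu_impl = SqReLUFunction.apply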

swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""

swiglu_bwd_codestring = """
template <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""

swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)

class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        return swiglu_bwd(x, y, dout)


swiglu = SwiGLUFunction.apply
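
# --- Illustrative usage sketch (not part of the original file) ---
# torch.cuda.jiterator kernels only run on CUDA tensors, so this check is
# guarded accordingly; the reference is the unfused F.silu(x) * y. The helper
# name `_check_swiglu` is ours, and tolerances are loose on purpose.
def _check_swiglu():
    if not torch.cuda.is_available():
        return
    x = torch.randn(8, 16, device="cuda", requires_grad=True)
    y = torch.randn(8, 16, device="cuda", requires_grad=True)
    out = swiglu(x, y)
    out.sum().backward()

    x_ref = x.detach().clone().requires_grad_()
    y_ref = y.detach().clone().requires_grad_()
    ref = F.silu(x_ref) * y_ref
    ref.sum().backward()

    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(y.grad, y_ref.grad, rtol=1e-4, atol=1e-4)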