@@ -102,3 +102,34 @@ def sqrelu_fwd(x):
 @torch.jit.script
 def sqrelu_bwd(g, x):
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
+class SwiGLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x, y)
+        return swiglu_fwd(x, y)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, y = ctx.saved_tensors
+        return swiglu_bwd(x, y, dout)
+
+swiglu = SwiGLUFunction.apply
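
Note on the backward kernel: the formulas can be read off from the forward. With s(x) = 1 / (1 + exp(-x)), the kernel computes f(x, y) = x * s(x) * y. Since d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x))), the gradients are dx = s(x) * (1 + x * (1 - s(x))) * g * y and dy = x * s(x) * g, which is exactly what swiglu_bwd computes in a single fused kernel.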
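
A minimal usage and sanity-check sketch, not part of the diff: it assumes a CUDA device (torch.cuda.jiterator kernels only run on GPU) and uses F.silu as an unfused reference; shapes and tolerances are illustrative.

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 1024, device="cuda", requires_grad=True)
    y = torch.randn(4, 1024, device="cuda", requires_grad=True)

    out = swiglu(x, y)      # fused forward: x * sigmoid(x) * y
    ref = F.silu(x) * y     # unfused reference for comparison
    assert torch.allclose(out, ref, atol=1e-5)

    out.sum().backward()    # routes through the fused jiterator backward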