Browse Source

[MLP] Implement SwiGLU with torch jiterator

Tri Dao 1 year ago
parent
commit
3557e0bb8f
2 changed files with 41 additions and 1 deletions
  1. 10 1
      flash_attn/modules/mlp.py
  2. 31 0
      flash_attn/ops/activations.py

+ 10 - 1
flash_attn/modules/mlp.py

@@ -1,10 +1,16 @@
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 from torch.distributed import ProcessGroup
 
 
+
+try:
+    from flash_attn.ops.activations import swiglu
+except ImportError:
+    swiglu = None
+
 try:
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
     from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 except ImportError:
 except ImportError:
@@ -120,6 +126,9 @@ class GatedMlp(nn.Module):
         y = self.fc1(x)
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
             y = F.glu(y, dim=-1)
+        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+            y, gate = y.chunk(2, dim=-1)
+            y = swiglu(gate, y)
         else:
         else:
             y, gate = y.chunk(2, dim=-1)
             y, gate = y.chunk(2, dim=-1)
             y = y * self.activation(gate)
             y = y * self.activation(gate)

+ 31 - 0
flash_attn/ops/activations.py

@@ -102,3 +102,34 @@ def sqrelu_fwd(x):
 @torch.jit.script
 @torch.jit.script
 def sqrelu_bwd(g, x):
 def sqrelu_bwd(g, x):
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
+class SwiGLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x, y)
+        return swiglu_fwd(x, y)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, y = ctx.saved_tensors
+        return swiglu_bwd(x, y, dout)
+
+swiglu = SwiGLUFunction.apply