Selaa lähdekoodia

Add GPT and ViT models

Tri Dao 2 vuotta sitten
vanhempi
commit
2e33fc8e36

+ 3 - 3
README.md

@@ -52,7 +52,7 @@ Our tentative roadmap:
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
-9. [Aug 2022] Fuse rotary embedding.
+9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
 
 ## Speedup and Memory Savings
@@ -154,10 +154,10 @@ and for his thoughtful answers to our questions about CUDA.
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
 ```
-@article{dao2022flashattention,
+@inproceedings{dao2022flashattention,
   title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  journal={arXiv preprint arXiv:2205.14135},
+  booktitle={Advances in Neural Information Processing Systems},
   year={2022}
 }
 ```

+ 1 - 1
csrc/fused_dense_lib/README.md

@@ -1,4 +1,4 @@
-This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
+This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
 (forward and backward), adapted from Apex's
 [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
 We make it work for bfloat16.

+ 1 - 1
csrc/layer_norm/README.md

@@ -1,4 +1,4 @@
-This CUDA extensions implements fused dropout + residual + LayerNorm, based on
+This CUDA extension implements fused dropout + residual + LayerNorm, based on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
 ```sh

+ 174 - 0
flash_attn/models/gpt.py

@@ -0,0 +1,174 @@
+# Copyright (c) 2022, Tri Dao.
+
+import math
+from functools import partial
+
+from collections import namedtuple
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+from flash_attn.modules.mha import MHA
+from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.block import Block
+from flash_attn.modules.embedding import GPT2Embeddings
+
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+except ImportError:
+    dropout_add_layer_norm = None
+
+try:
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
+except ImportError:
+    FusedDenseSqreluDense = None
+
+
+def create_mixer_cls(config, layer_idx=None):
+    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
+    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
+    if config.scale_attn_by_inverse_layer_idx:
+        assert layer_idx is not None
+        softmax_scale /= float(layer_idx + 1)
+    dwconv = getattr(config, 'attn_dwconv', False)
+    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    use_flash_attn = getattr(config, 'use_flash_attn', False)
+    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
+    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
+                        softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
+                        rotary_emb_dim=rotary_emb_dim,
+                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
+    return mixer_cls
+
+
+def create_mlp_cls(config, layer_idx=None):
+    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
+    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
+    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
+    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
+        mlp_cls = partial(Mlp, hidden_features=inner_dim,
+                          activation=partial(F.gelu, approximate='tanh'))
+    else:
+        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
+        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
+        if isinstance(mlp_checkpoint_lvl, Sequence):
+            assert layer_idx is not None
+            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
+        if fused_dense_gelu_dense:
+            mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
+                              checkpoint_lvl=mlp_checkpoint_lvl)
+        elif fused_dense_sqrelu_dense:
+            assert FusedDenseSqreluDense is not None
+            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
+                              checkpoint_lvl=mlp_checkpoint_lvl)
+        else:
+            raise RuntimeError('MLP type not supported')
+    return mlp_cls
+
+
+def create_block(config, layer_idx=None):
+    mixer_cls = create_mixer_cls(config, layer_idx)
+    mlp_cls = create_mlp_cls(config, layer_idx)
+    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
+    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
+                  prenorm=True, resid_dropout=config.resid_pdrop,
+                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
+    block.layer_idx = layer_idx
+    return block
+
+
+# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
+def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
+    if isinstance(module, nn.Linear):
+        nn.init.normal_(module.weight, std=initializer_range)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Embedding):
+        nn.init.normal_(module.weight, std=initializer_range)
+
+    if rescale_prenorm_residual:
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if name in ["out_proj.weight", "fc2.weight"]:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
+
+
+class GPT2Model(nn.Module):
+
+    def __init__(self, config: GPT2Config):
+        super().__init__()
+        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
+        if self.pad_vocab_size_multiple_8:
+            if config.vocab_size % 8 != 0:
+                config.vocab_size += 8 - (config.vocab_size % 8)
+
+        self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
+                                         config.max_position_embeddings)
+        self.emb_drop = nn.Dropout(config.embd_pdrop)
+
+        # We change the order of residual and layer norm:
+        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
+        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
+        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
+        # nn.LayerNorm weights are changed.
+        # This is for performance reason: we can fuse dropout + add + layer_norm.
+        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
+        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+            raise ImportError('dropout_add_layer_norm is not installed')
+        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
+        # is the final layer norm.
+        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
+                                     for i in range(config.num_hidden_layers)])
+
+        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
+                           initializer_range=config.initializer_range))
+
+    def forward(self, input_ids, position_ids=None):
+        hidden_states = self.embeddings(input_ids, position_ids=position_ids)
+        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
+        if not self.fused_dropout_add_ln:
+            residual = self.emb_drop(hidden_states).float()
+            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
+        else:
+            hidden_states, residual = dropout_add_layer_norm(
+                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
+                self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
+                residual_in_fp32=True
+            )
+        for layer in self.layers:
+            hidden_states, residual = layer(hidden_states, residual)
+        return hidden_states
+
+
+class GPT2LMHeadModel(nn.Module):
+
+    def __init__(self, config: GPT2Config):
+        super().__init__()
+        self.transformer = GPT2Model(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        # Initialize weights and apply final processing
+        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
+                           initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
+        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+
+    def forward(self, input_ids, position_ids=None):
+        hidden_states = self.transformer(input_ids, position_ids=position_ids)
+        lm_logits = self.lm_head(hidden_states)
+        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
+        return CausalLMOutput(logits=lm_logits)

+ 249 - 0
flash_attn/models/vit.py

@@ -0,0 +1,249 @@
+# Copyright (c) 2022, Tri Dao.
+# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+import math
+from functools import partial
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from einops import rearrange
+
+from timm.models.helpers import named_apply
+from flash_attn.layers.patch_embed import PatchEmbed
+
+from flash_attn.modules.mha import MHA
+from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.block import Block
+
+
+def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
+                     cross_attn=False):
+    mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
+                        dropout=attn_drop, fused_bias_fc=fused_bias_fc,
+                        use_flash_attn=use_flash_attn)
+    return mixer_cls
+
+
+def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
+    inner_dim = int(embed_dim * mlp_ratio)
+    if not fused_dense_gelu_dense:
+        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
+    else:
+        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
+    return mlp_cls
+
+
+def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
+                 norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
+                 fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False):
+    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
+                                 cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
+    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
+    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
+                  prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
+                  fused_dropout_add_ln=fused_dropout_add_ln)
+    return block
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+        - https://arxiv.org/abs/2010.11929
+    """
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='token',
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.,
+            qkv_bias=True,
+            init_values=None,
+            class_token=True,
+            no_embed_class=False,
+            pre_norm=False,
+            fc_norm=None,
+            drop_rate=0.,
+            attn_drop_rate=0.,
+            drop_path_rate=0.,
+            weight_init='',
+            embed_layer=PatchEmbed,
+            norm_layer=None,
+            act_layer=None,
+            use_flash_attn=False,
+            fused_bias_fc=False,
+            fused_dense_gelu_dense=False,
+            fused_dropout_add_ln=False,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            global_pool (str): type of global pooling for final sequence (default: 'token')
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            init_values: (float): layer-scale init values
+            class_token (bool): use class token
+            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            weight_init (str): weight init scheme
+            embed_layer (nn.Module): patch embedding layer
+            norm_layer: (nn.Module): normalization layer
+            act_layer: (nn.Module): MLP activation layer
+        """
+        super().__init__()
+        assert global_pool == 'token', 'Only support pooling with CLS token'
+        assert class_token
+        assert init_values is None, 'LayerScale is not supported yet'
+        assert weight_init == ''
+        assert fc_norm is None
+        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
+        assert not pre_norm
+        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
+
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_prefix_tokens = 1 if class_token else 0
+        self.no_embed_class = no_embed_class
+
+        patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
+                                    else {})
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **patch_embed_extra_kwargs
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        # We change the order of residual and layer norm:
+        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
+        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
+        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
+        # nn.LayerNorm weights are changed.
+        # This is for performance reason: we can fuse dropout + add + layer_norm.
+        # self.norm_0 is the first layer norm in the model, while self.norm
+        # (in the pretrained weight) is the final layer norm.
+        self.norm_0 = norm_layer(embed_dim)
+
+        self.blocks = nn.ModuleList([create_block(
+            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
+            norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
+            fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
+            fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
+            last_layer_subset=(global_pool == 'token')
+        ) for i in range(depth)])
+
+        # Classifier Head
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.init_weights(weight_init)
+
+    def init_weights(self, mode=''):
+        assert mode == ''
+        trunc_normal_(self.pos_embed, std=.02)
+        if self.cls_token is not None:
+            nn.init.normal_(self.cls_token, std=1e-6)
+        named_apply(init_weights_vit_timm, self)
+
+    def _init_weights(self, m):
+        # this fn left here for compat with downstream users
+        init_weights_vit_timm(m)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def _pos_embed(self, x):
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + self.pos_embed
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            x = x + self.pos_embed
+        return self.pos_drop(x)
+
+    def forward_features(self, x, all_tokens=True):
+        """
+        If all_tokens==False and self.global_pool == 'token', we only return the features for the
+        cls token.
+        """
+        x = self.patch_embed(x)
+        # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
+        residual = self._pos_embed(x).float()
+        hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
+        if self.global_pool != 'token' or all_tokens:
+            for block in self.blocks:
+                hidden_states, residual = block(hidden_states, residual)
+        else:
+            for block in self.blocks[:-1]:
+                hidden_states, residual = block(hidden_states, residual)
+            # For the last layer, we only want the 1st token of the output. So we do cross-attention
+            # where the query is the 1st token and the key/value is the whole sequence.
+            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
+            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
+            hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st,
+                                               mixer_kwargs={'x_kv': hidden_states})
+        return hidden_states
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool:
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x, all_tokens=False)
+        x = self.forward_head(x)
+        return x
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ''):
+    """ ViT weight initialization, original timm impl (for reproducibility) """
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
+def vit_base_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    assert not pretrained
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = VisionTransformer(**model_kwargs)
+    return model

+ 162 - 0
flash_attn/ops/triton/k_activations.py

@@ -0,0 +1,162 @@
+# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from enum import Enum
+from typing import Optional
+
+import triton
+import triton.language as tl
+
+
+_sqrt2pi = math.sqrt(2.0 / math.pi)
+_sqrt1_2 = math.sqrt(1.0 / 2)
+_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
+
+
+class Activation(str, Enum):
+    SquaredReLU = "squared_relu"
+    GeLU = "gelu"
+    GeLUApprox = "gelu_approx"
+    LeakyReLU = "leaky_relu"
+    ReLU = "relu"
+
+
+def get_triton_activation_kernel(activation: Optional[Activation]):
+    return (
+        {
+            Activation.ReLU: relu,
+            Activation.LeakyReLU: leaky_relu,
+            Activation.GeLU: gelu,
+            Activation.GeLUApprox: gelu_approx,
+            Activation.SquaredReLU: squared_relu,
+        }[activation]
+        if activation
+        else None
+    )
+
+
+def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
+    return (
+        {
+            Activation.ReLU: relu_grad,
+            Activation.LeakyReLU: leaky_relu_grad,
+            Activation.GeLU: gelu_grad,
+            Activation.GeLUApprox: gelu_approx_grad,
+            Activation.SquaredReLU: squared_relu_grad,
+        }[activation]
+        if activation
+        else None
+    )
+
+
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def cosh(x):
+    exp_x = tl.exp(x)
+    return (exp_x + 1.0 / exp_x) * 0.5
+
+
+# a Triton implementation of the most used activations
+# See for instance http://arxiv.org/abs/1606.08415 for an overview
+
+# ReLU
+@triton.jit
+def relu(x):
+    """
+    ReLU_ activation function
+
+    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
+    """
+    zero = 0.0
+    return tl.where(x >= 0, x, zero.to(x.dtype))
+
+
+@triton.jit
+def relu_grad(x):
+    # ReLU is different from other activations
+    # in that it does not require the input to retrospectively compute its gradient
+    # here the input is the downstream gradient, and we return the upstream gradient directly
+    zero = 0.0
+    one = 1.0
+    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))
+
+
+@triton.jit
+def squared_relu(x):
+    """
+    Squared ReLU activation, as proposed in the Primer_ paper.
+
+    .. _Primer: https://arxiv.org/abs/2109.08668
+    """
+    x_ = relu(x)
+    return (x_ * x_).to(x.dtype)
+
+
+@triton.jit
+def squared_relu_grad(x):
+    return tl.where(x >= 0, 2.0 * x, 0.0)
+
+
+# Leaky ReLU
+@triton.jit
+def leaky_relu(x):
+    """
+    LeakyReLU_ activation
+
+    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
+    """
+    scale = 0.01 + 0.0
+    scale = scale.to(x.dtype)
+    return tl.where(x >= 0, x, scale * x)
+
+
+@triton.jit
+def leaky_relu_grad(x):
+    min_grad = 0.01
+    max_grad = 1
+
+    min_grad = min_grad.to(x.dtype)
+    max_grad = max_grad.to(x.dtype)
+
+    return tl.where(x >= 0, max_grad, min_grad)
+
+
+@triton.jit
+def gelu(x):
+    """Gaussian Error Linear Unit (GELU)"""
+    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
+
+
+@triton.jit
+def gelu_grad(x):
+    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
+    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
+    return cdf + x * pdf
+
+@triton.jit
+def gelu_approx(x):
+    """
+    GeLU_ activation - Gaussian error linear unit, with tanh approximation
+
+    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
+    """
+    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))
+
+
+@triton.jit
+def gelu_approx_grad(x):
+    # CREDITS: Fast implementation proposed in
+    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
+    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    return 0.5 * x * (
+        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
+    ) + 0.5 * (1 + tanh_out)

+ 479 - 0
flash_attn/ops/triton/linear.py

@@ -0,0 +1,479 @@
+# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
+# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from torch.autograd.function import FunctionCtx
+from torch.cuda.amp import custom_fwd
+from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad
+
+
+# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
+
+
+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()
+
+
+def get_configs_io_bound():
+    configs = []
+    for num_stages in [2, 3, 4, 5, 6]:
+        for block_m in [16, 32]:
+            for block_k in [32, 64]:
+                for block_n in [32, 64, 128, 256]:
+                    num_warps = 2 if block_n <= 64 else 4
+                    configs.append(
+                        triton.Config(
+                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                            num_stages=num_stages,
+                            num_warps=num_warps,
+                        )
+                    )
+                    # split_k not used
+                    # for split_k in [2, 4, 8, 16]:
+                    #     configs.append(triton.Config(
+                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
+                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+    return configs
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+    ]
+    + get_configs_io_bound(),
+    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
+    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+)
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+    }
+)
+@triton.jit
+def kernel_fwd(
+    C,  # Pointers to matrices
+    ACT_INPUT,
+    A,
+    B,
+    bias,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    CACHE_KEY_M,
+    CACHE_KEY_N,
+    CACHE_KEY_K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
+    # by to get the element one row down (A has M rows)
+    stride_cm,
+    # stride_cn,  # Assume that stride_cn == 1
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # split k not used, not performant with activation, kept because early_config_prune is expecting it
+    SPLIT_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    A_ROWMAJOR: tl.constexpr,
+    B_COLMAJOR: tl.constexpr,
+    BIAS: tl.constexpr,
+    SAVE_ACT_INPUT: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+
+    """
+    Kernel for computing Out = activation(A x W + C)
+    - Input has shape (M, K)
+    - Weight has shape (K, N)
+    - Bias has shape (N,)
+    - Output has shape (M, N)
+    - ActInputs (optional) has shape (M, N)
+    'ActInputs' optionally saves the A x W + C intermediate for backward computations
+    This kernel will consolidate over K
+    """
+
+    pid = tl.program_id(axis=0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+
+    # now compute the block that each program will go through
+    # rm (resp. rn) denotes a range of indices
+    # for rows (resp. col) of C
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    # trick to avoid masking on M and N axis
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+
+    if A_ROWMAJOR:
+        A = A + (ram[:, None] * stride_am + rk[None, :])
+    else:
+        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    if B_COLMAJOR:
+        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
+    else:
+        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(K, 0, -BLOCK_K):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
+            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
+        acc += tl.dot(a, b)
+
+        if A_ROWMAJOR:
+            A += BLOCK_K
+        else:
+            A += BLOCK_K * stride_ak
+        if B_COLMAJOR:
+            B += BLOCK_K
+        else:
+            B += BLOCK_K * stride_bk
+
+    # Putting bias after the matmul (instead of before) is faster, idk why
+    if BIAS:
+        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
+        acc += bias[None, :]
+
+    # optional: save the activation inputs
+    if SAVE_ACT_INPUT:
+        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
+        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
+        tl.store(act_in_ptrs, acc)
+
+    # optional: fused activation (while the data is in shared memory)
+    if ACTIVATION == "gelu":
+        acc = gelu(acc)
+    elif ACTIVATION == "gelu_approx":
+        acc = gelu_approx(acc)
+    elif ACTIVATION == "squared_relu":
+        acc = squared_relu(acc)
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # write back result
+    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
+    C = C + rm[:, None] * stride_cm + rn[None, :]
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(C, acc)
+
+
+def triton_linear_act(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    activation: str = 'id',
+    save_act_input: bool = False,
+) -> torch.Tensor:
+    """
+    Compute e = activation(x @ weight.T + bias).
+    This wrapper kicks the `kernel_fwd` Triton kernel
+    :param x: input tensor
+    :param weight: weight matrix
+    :param bias: an optional bias tensor
+    :param activation: Activation name. Needs to be a Triton kernel.
+    :param act_input: an optional tensor to save the activation inputs (for backward)
+    :return: result tensor
+    """
+    # if torch.is_autocast_enabled():
+    #     dtype = torch.get_autocast_gpu_dtype()
+    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
+
+    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
+
+    batch_shape, n = x.shape[:-1], x.shape[-1]
+    batch_dim = batch_shape.numel()
+    x_reshaped = x.reshape(batch_dim, n)
+
+    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
+        x_reshaped = x_reshaped.contiguous()
+    if weight.stride(0) > 1 and weight.stride(1) > 1:
+        weight = weight.contiguous()
+    bias = bias.contiguous() if bias is not None else None
+
+    assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
+    if bias is not None:
+        assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
+    assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
+
+    assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
+
+    M, K = x_reshaped.shape
+    N, K = weight.shape
+
+    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
+    act_input = torch.empty_like(output) if save_act_input else None
+
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa
+
+    kernel_fwd[grid](
+        output,
+        act_input,
+        x_reshaped,
+        weight,  # data ptrs
+        bias if bias is not None else x,  # auto skip bias if not present
+        M,  # shapes
+        N,
+        K,
+        M // 32,  # key for triton cache (limit number of compilations)
+        N // 32,
+        K // 32,
+        stride_cm=output.stride(0),  # strides
+        # stride_cn=output.stride(1),
+        stride_am=x_reshaped.stride(0),
+        stride_ak=x_reshaped.stride(1),
+        stride_bk=weight.stride(1),
+        stride_bn=weight.stride(0),
+        BIAS=bias is not None,  # optional fused bias
+        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
+        ACTIVATION=activation,  # optional fused activation
+        A_ROWMAJOR=x_reshaped.stride(1) == 1,
+        B_COLMAJOR=weight.stride(1) == 1,
+        GROUP_M=8,  # speed optimization: group the programs
+    )
+
+    if not save_act_input:
+        return output.reshape(*batch_shape, output.shape[-1])
+    else:
+        return (output.reshape(*batch_shape, output.shape[-1]),
+                act_input.reshape(*batch_shape, act_input.shape[-1]))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+    ]
+    + get_configs_io_bound(),
+    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
+    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+)
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+    }
+)
+@triton.jit
+def kernel_bwd(
+    C,  # Pointers to matrices
+    ACT_INPUT,
+    A,
+    B,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    CACHE_KEY_M,
+    CACHE_KEY_N,
+    CACHE_KEY_K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
+    # by to get the element one row down (A has M rows)
+    stride_cm,
+    # stride_cn,  # Assume that stride_cn == 1
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # split k not used, not performant with activation, kept because early_config_prune is expecting it
+    SPLIT_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+
+    """
+    Kernel for computing Out = activation(A x W + C)
+    - Input has shape (M, K)
+    - Weight has shape (K, N)
+    - Output has shape (M, N)
+    - ActInputs (optional) has shape (M, N)
+    'ActInputs' optionally saves the A x W + C intermediate for backward computations
+    This kernel will consolidate over K
+    """
+
+    pid = tl.program_id(axis=0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+
+    # now compute the block that each program will go through
+    # rm (resp. rn) denotes a range of indices
+    # for rows (resp. col) of C
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    # trick to avoid masking on M and N axis
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(K, 0, -BLOCK_K):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
+            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
+        acc += tl.dot(a, b)
+
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+
+    # optional: fused activation (while the data is in shared memory)
+    if ACTIVATION != 'id':
+        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
+        act_input = tl.load(act_in_ptrs).to(acc.dtype)
+    if ACTIVATION == "gelu":
+        acc *= gelu_grad(act_input)
+    elif ACTIVATION == "gelu_approx":
+        acc *= gelu_approx_grad(act_input)
+    elif ACTIVATION == "squared_relu":
+        acc *= squared_relu_grad(act_input)
+
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # write back result
+    C = C + rm[:, None] * stride_cm + rn[None, :]
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(C, acc, mask=mask)
+
+
+def triton_dgrad_act(
+    grad_output: torch.Tensor,
+    weight: torch.Tensor,
+    activation: str = 'id',
+    act_input: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Compute e = activation(grad_output @ weight + bias).
+    This wrapper kicks the `kernel_fwd` Triton kernel
+    :param grad_output: input tensor
+    :param weight: weight matrix
+    :param activation: Activation name. Needs to be a Triton kernel.
+    :param act_input: an optional tensor to save the activation inputs (for backward)
+    :return: result tensor
+    """
+    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
+
+    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
+    batch_dim = batch_shape.numel()
+    grad_output_reshaped = grad_output.reshape(batch_dim, n)
+
+    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
+        grad_output_reshaped = grad_output_reshaped.contiguous()
+    if weight.stride(0) > 1 and weight.stride(1) > 1:
+        weight = weight.contiguous()
+
+    assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
+    assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
+    if activation != 'id':
+        assert act_input is not None, f'act_input is required for activation {activation}'
+
+    # M, N, K in bwd are different from M, N, K in fwd
+    M, K = grad_output_reshaped.shape
+    K, N = weight.shape
+
+    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)
+
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa
+
+    kernel_bwd[grid](
+        grad_input,
+        act_input,
+        grad_output_reshaped,
+        weight,  # data ptrs
+        M,  # shapes
+        N,
+        K,
+        M // 32,  # key for triton cache (limit number of compilations)
+        N // 32,
+        K // 32,
+        stride_cm=grad_input.stride(0),  # strides
+        # stride_cn=grad_input.stride(1),
+        stride_am=grad_output_reshaped.stride(0),
+        stride_ak=grad_output_reshaped.stride(1),
+        stride_bk=weight.stride(0),
+        stride_bn=weight.stride(1),
+        ACTIVATION=activation,  # optional fused activation
+        GROUP_M=8,  # speed optimization: group the programs
+    )
+
+    return grad_input.reshape(*batch_shape, grad_input.shape[-1])

+ 140 - 0
flash_attn/ops/triton/mlp.py

@@ -0,0 +1,140 @@
+# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
+# to naive implementation.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+import fused_dense_lib as fused_dense_cuda
+
+from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
+
+
+@torch.jit.script
+def sqrelu_fwd(x):
+    r = F.relu(x)
+    return (r * r).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_bwd(g, x):
+    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+class FusedDenseSqreluDenseFunc(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
+        """checkpoint_lvl:
+        0: no recomputation in the bwd
+        1: recompute gelu_out in the bwd
+        2: recompute act_input and gelu_out in the bwd
+        """
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
+                                                 for a in [x, weight1, bias1, weight2, bias2]]
+        is_bf16 = x.dtype == torch.bfloat16
+        assert checkpoint_lvl in [0, 1, 2]
+        x = x.contiguous()
+        weight1 = weight1.contiguous()
+        bias1 = bias1.contiguous()
+        weight2 = weight2.contiguous()
+        bias2 = bias2.contiguous()
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        if is_bf16:
+            act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+            output1 = sqrelu_fwd(act_input)
+        else:
+            save_act_input = checkpoint_lvl != 2
+            result = triton_linear_act(
+                x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
+                save_act_input=save_act_input
+            )
+            if save_act_input:
+                output1, act_input = result
+            else:
+                output1 = result
+        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
+        ctx.checkpoint_lvl = checkpoint_lvl
+        if checkpoint_lvl == 0:
+            ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1)
+        elif checkpoint_lvl == 1:
+            ctx.save_for_backward(x, weight1, bias1, weight2, act_input)
+        elif checkpoint_lvl == 2:
+            ctx.save_for_backward(x, weight1, bias1, weight2)
+        return output2.reshape(*batch_shape, output2.shape[-1])
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+        grad_output = grad_output.contiguous()
+        checkpoint_lvl = ctx.checkpoint_lvl
+        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
+        batch_shape, n = x.shape[:-1], x.shape[-1]
+        batch_dim = batch_shape.numel()
+        is_bf16 = x.dtype == torch.bfloat16
+        if checkpoint_lvl == 0:
+            act_input, output1 = rest
+        elif checkpoint_lvl == 1:
+            act_input, = rest
+            output1 = sqrelu_fwd(act_input)
+        elif checkpoint_lvl == 2:
+            if is_bf16:
+                act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
+                output1 = sqrelu_fwd(act_input)
+            else:
+                output1, act_input = triton_linear_act(
+                    x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
+                    save_act_input=True
+                )
+
+        if is_bf16:
+            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
+            grad_output1 = grad_output @ weight2
+            grad_act_input = sqrelu_bwd(grad_output1, act_input)
+            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
+                x.reshape(batch_dim, n), weight1, grad_act_input
+            )
+        else:
+            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
+            grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu',
+                                              act_input=act_input)
+            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
+                x.reshape(batch_dim, n), weight1, grad_act_input
+            )
+        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None
+
+
+fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
+
+
+class FusedDenseSqreluDense(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
+                 checkpoint_lvl=0, device=None, dtype=None):
+        """
+        checkpoint_lvl (increasing lvl means slower but more memory saving):
+            0: no recomputation in the bwd
+            1: recompute gelu_out in the bwd
+            2: recompute gelu_in and gelu_out in the bwd
+        """
+        assert checkpoint_lvl in [0, 1, 2]
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        assert bias == True, "DenseSqreluDense module without bias is currently not supported"
+        self.checkpoint_lvl = checkpoint_lvl
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
+
+    def forward(self, x):
+        assert x.is_cuda
+        return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias,
+                                                 self.fc2.weight, self.fc2.bias,
+                                                 self.checkpoint_lvl)