
FlashAttention adoption

We've been very happy to see FlashAttention adopted by many organizations and research labs to speed up their training and inference, all within 6 months of FlashAttention's release (at the time of writing). This page contains a partial list of places where FlashAttention is being used. If you'd like to add links to your organization, product, or codebase, please open a PR or email us. We'd very much like to hear from you!

Integrated into machine learning frameworks

  • PyTorch: integrated into core PyTorch in nn.Transformer.

  • Hugging Face's transformers library: integration is ongoing, with a blog post coming soon.

  • Microsoft's DeepSpeed: FlashAttention is integrated into DeepSpeed's inference engine.

  • Nvidia's Megatron-LM. This library is a popular framework for training large transformer language models at scale.

  • MosaicML Composer library. Composer is a library for efficient neural network training.

  • EleutherAI's GPT-NeoX. This is a research library for training large transformer language models at scale, based on NVIDIA's Megatron-LM and Microsoft's DeepSpeed.

  • PaddlePaddle: integrated into the framework with API paddle.nn.functional.flash_attention.
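As an illustration of what a framework-level integration looks like from the user's side: since PyTorch 2.0, fused attention kernels, including one based on FlashAttention, are exposed through `torch.nn.functional.scaled_dot_product_attention`. The sketch below is not an official recipe. It assumes PyTorch 2.0 or later, uses arbitrary illustrative tensor shapes, and simply checks the fused call against a naive reference. Which backend actually runs (FlashAttention, memory-efficient, or plain math) is decided by PyTorch based on device, dtype, shapes, and version; the FlashAttention path generally needs a supported CUDA GPU and fp16/bf16 inputs.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    # Reference: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix.
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
# Illustrative shapes (assumed): (batch, heads, seq_len, head_dim).
q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# PyTorch picks the kernel; the default scale is 1 / sqrt(head_dim),
# matching the reference above.
out = F.scaled_dot_product_attention(q, k, v)
ref = naive_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-5))
```

To request the FlashAttention backend explicitly on a GPU, recent PyTorch releases provide a backend-selection context manager, but its name and location have changed between versions, so check the documentation for the version you use.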

MLPerf benchmarks

MLPerf is a competitive machine learning performance benchmark. FlashAttention yields the fastest BERT training on cloud instances in MLPerf training 2.0 (June 2022) and MLPerf training 2.1 (November 2022).

  • MLPerf 2.0: IEEE Spectrum and Forbes articles about our submission to the MLPerf 2.0 benchmark using FlashAttention.

  • MLPerf 2.1 - collaboration between Azure and Hazy Research: for the first time, we can train MLPerf BERT in under 2 minutes on 16 nodes.

  • MLPerf 2.1 - Nvidia: Nvidia uses techniques from FlashAttention to make their (already extremely optimized) BERT implementation go even faster.

  • MLPerf 2.1 - MosaicML: FlashAttention helps train BERT 2.7x faster in the open division.

Language model training & inference

  • PubMedGPT 2.7B, a domain-specific LLM for biomedicine by Stanford CRFM, trained on MosaicML Cloud. Using FlashAttention alone nearly halves the total training time.

  • Meta's AITemplate uses FlashAttention as part of its approach to speed up Transformer inference (up to 5.3x on BERT).

  • Nvidia's FasterTransformer is a state-of-the-art Transformer inference library. As of version 5.2, FlashAttention is used as a component of FasterTransformer to speed up GPT inference.

  • Kernl is a library for fast Transformer inference. They use FlashAttention as part of their approach to speed up Transformers by up to 12x.

Diffusion model training and inference

  • Hugging Face's diffusers library for diffusion models. FlashAttention is integrated into diffusers v0.7.0, giving up to 2x faster inference and lower memory usage.

  • Colossal-AI's implementation of Stable Diffusion: with FlashAttention as one of its components, it speeds up pretraining by up to 6.5x, and reduces the hardware cost of fine-tuning by 7x.

  • Meta's AITemplate, with FlashAttention as one of its components, is currently the fastest Stable Diffusion inference engine that we know of.

  • Stable Diffusion inference from Labml.ai: 50% speedup.

  • Our own Stable Diffusion fork uses FlashAttention to get 3-4x speedup compared to the original version.

Other models

  • Uni-Fold: Uni-Fold is an open-source platform for developing protein models beyond AlphaFold. With FlashAttention, Uni-Fold is 2.6x faster than AlphaFold.

  • OpenFold: a trainable, memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2. With FlashAttention as one of its components, it runs inference up to 3x faster than AlphaFold 2 on short sequences and can predict structures 2x longer.

Different implementations

  • Triton: an implementation of FlashAttention in Triton by Phil Tillet from OpenAI. Triton is a Python-based language and compiler for parallel programming.

  • xformers: The xformers team has implemented memory-efficient attention in a similar spirit to FlashAttention. xformers dynamically dispatches to whichever implementation is available / faster.

  • JAX: an implementation in JAX by lucidrains.

  • Metal: an implementation in Metal by Philip Turner. This ports FlashAttention to mobile GPU architectures such as Apple silicon.
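FlashAttention, and reimplementations in the same spirit, compute exact attention without materializing the full score matrix. Keys and values are processed one block at a time, and a running maximum and running softmax denominator are rescaled as each new block arrives (the "online softmax" trick). The pure-Python sketch below illustrates only that rescaling step for a single query vector. The function names and block size are hypothetical; real kernels also tile the queries, keep blocks in fast on-chip memory, and implement a backward pass. The sketch checks that the blocked result matches a naive softmax.

```python
import math
import random


def naive_attention(q, k, v):
    """Standard attention for one query: softmax(q . K^T / sqrt(d)) V."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, row)) / math.sqrt(d) for row in k]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(v[0])
    return [sum(w * vr[j] for w, vr in zip(weights, v)) / total for j in range(dim_v)]


def tiled_attention(q, k, v, block_size=4):
    """Same result, visiting keys/values block by block with an online softmax."""
    d = len(q)
    dim_v = len(v[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # softmax denominator relative to running_max
    acc = [0.0] * dim_v      # unnormalized weighted sum of value rows
    for start in range(0, len(k), block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, row)) / math.sqrt(d) for row in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics to the new maximum (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, vr in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim_v):
                acc[j] += w * vr[j]
        running_max = new_max
    # Normalize once at the end.
    return [a / running_sum for a in acc]


random.seed(0)
n, d = 10, 8
q = [random.gauss(0, 1) for _ in range(d)]
k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
v = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

out_naive = naive_attention(q, k, v)
out_tiled = tiled_attention(q, k, v, block_size=3)
print(max(abs(a - b) for a, b in zip(out_naive, out_tiled)))
```

The block size of 3 is chosen so that the last block is only partially filled, which exercises the uneven-tail case; the two results should agree up to floating-point rounding.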