
# Flash Attention Triton Kernel

## Introduction

The Triton implementation of Flash Attention v2 is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using the fp16, bf16, and fp32 data types.

These features are supported in both the forward and backward passes (see the sketch below):

1) Causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes
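For illustration, here is a minimal sketch of such a call, assuming the standard `flash_attn_func` interface this repository exposes; shapes and sizes are illustrative:

```python
# Minimal sketch of a forward + backward call, assuming the standard
# flash_attn_func interface. "cuda" is the device name PyTorch uses on ROCm.
import torch
from flash_attn import flash_attn_func

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 256, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim,
                dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, headdim,
                dtype=torch.float16, device="cuda", requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, headdim,
                dtype=torch.float16, device="cuda", requires_grad=True)

# Causal masking with different Q and KV sequence lengths; both the
# forward and backward passes are supported for this configuration.
out = flash_attn_func(q, k, v, causal=True)
out.sum().backward()
```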

These features are supported in the forward pass for now, with backward support planned (see the sketch below):

1) Multi-query and grouped-query attention
2) ALiBi and matrix bias
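A forward-only sketch combining grouped-query attention with ALiBi, again assuming the upstream `flash_attn_func` signature, where `alibi_slopes` holds one float per query head:

```python
# Forward-only sketch: grouped-query attention (fewer KV heads than
# Q heads) combined with ALiBi. Assumes the upstream flash_attn_func
# signature; shapes are illustrative.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, nheads_k, headdim = 2, 128, 8, 2, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")

# One ALiBi slope per query head.
alibi_slopes = torch.rand(nheads, dtype=torch.float32, device="cuda")

out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)
# Per the list above, backward for these features is not supported yet,
# so no .backward() here.
```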

These features are in development:

1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements

## Getting Started

To get started with the Triton backend for AMD, follow the steps below.

First, install the recommended Triton commit:

```bash
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install --verbose -e python
```
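As an optional sanity check (not part of the official steps), you can confirm the editable install imports cleanly:

```python
# Optional sanity check: the editable Triton install should import
# and report a version string.
import triton
print(triton.__version__)
```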

Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`:

```bash
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```
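The environment variable appears to gate backend selection when `flash_attn` is imported, so a standalone script should set it before the import. A minimal smoke test under that assumption:

```python
# Minimal smoke-test sketch. Assumption: the flag is read when
# flash_attn is imported, so set it first (equivalent to the
# shell export above).
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 16, 4, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 4, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 4, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 16, 4, 64])
```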

## Credits

- AMD Triton kernels team
- OpenAI kernel team