The Triton implementation of Flash Attention v2 is currently a work in progress.
It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using the fp16, bf16, and fp32 data types.
The following features are supported in both the forward (Fwd) and backward (Bwd) passes:
1) Causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes
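To make the masking and sequence-length semantics above concrete, here is a naive, pure-Python single-head attention reference. It is an illustrative sketch, not the Triton kernel. It assumes the bottom-right-aligned causal mask that upstream flash-attn uses when seqlen_q != seqlen_k, and that a fully masked query row produces zeros. The helper names `naive_attention` and `varlen_attention` are hypothetical; the `cu_seqlens` arguments follow the cumulative-offset convention of flash-attn's varlen interface.

```python
import math


def naive_attention(q, k, v, causal=False):
    """Single-head softmax attention on lists of lists (illustrative only).

    q: seqlen_q x d, k: seqlen_k x d, v: seqlen_k x d_v.
    Assumption: with causal=True the mask is aligned to the bottom-right
    corner, so query i may attend to key j only if j <= i + seqlen_k - seqlen_q.
    """
    seqlen_q, seqlen_k = len(q), len(k)
    scale = 1.0 / math.sqrt(len(q[0]))  # works for any head size d
    offset = seqlen_k - seqlen_q
    d_v = len(v[0])
    out = []
    for i, qi in enumerate(q):
        scores = [
            float("-inf") if causal and j > i + offset
            else scale * sum(a * b for a, b in zip(qi, kj))
            for j, kj in enumerate(k)
        ]
        m = max(scores)
        if m == float("-inf"):
            # Assumption: a fully masked query row outputs zeros.
            out.append([0.0] * d_v)
            continue
        # Subtract the row max for numerical stability before exponentiating.
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v)) / total
            for c in range(d_v)
        ])
    return out


def varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False):
    """Attention over packed variable-length sequences (hypothetical helper).

    cu_seqlens_* hold cumulative offsets, e.g. [0, 1, 3] packs two sequences
    of lengths 1 and 2. Each sequence attends only within itself.
    """
    out = []
    for b in range(len(cu_seqlens_q) - 1):
        qs, qe = cu_seqlens_q[b], cu_seqlens_q[b + 1]
        ks, ke = cu_seqlens_k[b], cu_seqlens_k[b + 1]
        out.extend(naive_attention(q[qs:qe], k[ks:ke], v[ks:ke], causal))
    return out
```

Packing sequences with cumulative offsets avoids padding: each segment is processed independently, so sequences of different lengths never attend to each other.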
The following features are currently supported in the forward pass only; we will add them to the backward pass soon:
1) Multi-query and grouped-query attention
2) ALiBi and matrix bias
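As a rough illustration of the forward-only features, the sketch below shows how grouped-query attention maps query heads onto shared KV heads (multi-query attention is the case of a single KV head) and how ALiBi slopes and biases can be computed. The function names are hypothetical. The slope formula 2^(-8h/n) covers only power-of-two head counts, and the bias form -slope * |i + seqlen_k - seqlen_q - j| is assumed from the upstream flash-attn convention rather than taken from this backend. A matrix bias, by contrast, is simply an arbitrary seqlen_q x seqlen_k tensor added to the attention scores before the softmax.

```python
def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    """Map a query head to the KV head it shares (GQA; MQA when num_kv_heads == 1)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


def alibi_slopes(num_heads):
    """Per-head ALiBi slopes 2**(-8h/n), h = 1..n (power-of-two head counts only)."""
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    return [2.0 ** (-8.0 * h / num_heads) for h in range(1, num_heads + 1)]


def alibi_bias(slope, i, j, seqlen_q, seqlen_k):
    """Additive score bias for query i and key j, assuming the convention
    -slope * |i + seqlen_k - seqlen_q - j|."""
    return -slope * abs(i + seqlen_k - seqlen_q - j)
```

For example, with 8 query heads and 2 KV heads, query heads 0 to 3 read KV head 0 and query heads 4 to 7 read KV head 1, which shrinks the K/V memory footprint by a factor of four.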
The following features are in development:
1) Paged attention
2) Sliding window
3) Rotary embeddings
4) Dropout
5) Performance improvements
To get started with the Triton backend for AMD, follow the steps below.
First, install the recommended Triton commit:
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install --verbose -e python
Then install and test Flash Attention with the environment variable FLASH_ATTENTION_TRITON_AMD_ENABLE set to "TRUE":
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
Credits:
AMD Triton kernels team
OpenAI kernel team