|
@@ -75,6 +75,69 @@ Our tentative roadmap:
|
|
|
10. [Jun 2023] Support SM70 GPUs (V100).
|
|
|
11. [Jun 2023] Support SM90 GPUs (H100).
|
|
|
|
|
|
+
|
|
|
+## How to use FlashAttention
|
|
|
+
|
|
|
+Here's a simple example:
|
|
|
+```python
|
|
|
+import torch
|
|
|
+from flash_attn.flash_attention import FlashMHA
|
|
|
+
|
|
|
+# Replace this with your correct GPU device
|
|
|
+device = "cuda:0"
|
|
|
+
|
|
|
+# Create attention layer. This is similar to torch.nn.MultiheadAttention,
|
|
|
+# and it includes the input and output linear layers
|
|
|
+flash_mha = FlashMHA(
|
|
|
+ embed_dim=128, # total channels (= num_heads * head_dim)
|
|
|
+ num_heads=8, # number of heads
|
|
|
+ device=device,
|
|
|
+ dtype=torch.float16,
|
|
|
+)
|
|
|
+
|
|
|
+# Run forward pass with dummy data
|
|
|
+x = torch.randn(
|
|
|
+ (64, 256, 128), # (batch, seqlen, embed_dim)
|
|
|
+ device=device,
|
|
|
+ dtype=torch.float16
|
|
|
+)
|
|
|
+
|
|
|
+output = flash_mha(x)[0]
|
|
|
+```
|
|
|
+
|
|
|
+Alternatively, you can import the inner attention layer only (so that the input
|
|
|
+and output linear layers are not included):
|
|
|
+```python
|
|
|
+from flash_attn.flash_attention import FlashAttention
|
|
|
+
|
|
|
+# Create the nn.Module
|
|
|
+flash_attention = FlashAttention()
|
|
|
+```
|
|
|
+
|
|
|
+Or, if you need more fine-grained control, you can import one of the lower-level
|
|
|
+functions (this is more similar to the `torch.nn.functional` style):
|
|
|
+```python
|
|
|
+from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
|
|
+
|
|
|
+# or
|
|
|
+
|
|
|
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
|
|
|
+
|
|
|
+# etc.
|
|
|
+```
|
|
|
+
|
|
|
+There are also separate Python files with various FlashAttention extensions:
|
|
|
+```python
|
|
|
+# Import the triton implementation (torch.nn.functional version only)
|
|
|
+from flash_attn.flash_attn_triton import flash_attn_func
|
|
|
+
|
|
|
+# Import block sparse attention (nn.Module version)
|
|
|
+from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
|
|
|
+
|
|
|
+# Import block sparse attention (torch.nn.functional version)
|
|
|
+from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
|
|
|
+```
|
|
|
+
|
|
|
## Speedup and Memory Savings
|
|
|
|
|
|
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
|