瀏覽代碼

Merge pull request #94 from calebthomas259/main

Add a simple tutorial to README.md
Tri Dao 2 年之前
父節點
當前提交
57ee618170
共有 1 個文件被更改,包括 63 次插入0 次删除
  1. 63 0
      README.md

+ 63 - 0
README.md

@@ -75,6 +75,69 @@ Our tentative roadmap:
 10. [Jun 2023] Support SM70 GPUs (V100).
 11. [Jun 2023] Support SM90 GPUs (H100).
 
+
+## How to use FlashAttention
+
+Here's a simple example:
+```python
+import torch
+from flash_attn.flash_attention import FlashMHA
+
+# Replace this with your correct GPU device
+device = "cuda:0"
+
+# Create attention layer. This is similar to torch.nn.MultiheadAttention,
+# and it includes the input and output linear layers
+flash_mha = FlashMHA(
+    embed_dim=128, # total channels (= num_heads * head_dim)
+    num_heads=8, # number of heads
+    device=device,
+    dtype=torch.float16,
+)
+
+# Run forward pass with dummy data
+x = torch.randn(
+    (64, 256, 128), # (batch, seqlen, embed_dim)
+    device=device,
+    dtype=torch.float16
+)
+
+output = flash_mha(x)[0]
+```
+
+Alternatively, you can import the inner attention layer only (so that the input
+and output linear layers are not included):
+```python
+from flash_attn.flash_attention import FlashAttention
+
+# Create the nn.Module
+flash_attention = FlashAttention()
+```
+
+Or, if you need more fine-grained control, you can import one of the lower-level
+functions (this is more similar to the `torch.nn.functional` style):
+```python
+from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+
+# or
+
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
+
+# etc.
+```
+
+There are also separate Python files with various FlashAttention extensions:
+```python
+# Import the triton implementation (torch.nn.functional version only)
+from flash_attn.flash_attn_triton import flash_attn_func
+
+# Import block sparse attention (nn.Module version)
+from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
+
+# Import block sparse attention (torch.nn.functional version)
+from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+```
+
 ## Speedup and Memory Savings
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).