Browse Source

Don't support softcap and dropout at the same time

These tests are failing so I'm just disabling this case for now
Tri Dao 8 months ago
parent
commit
dca6d89da4

+ 4 - 0
csrc/flash_attn/flash_api.cpp

@@ -387,6 +387,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
 
@@ -589,6 +591,8 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     const int head_size_og = sizes[2];
     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
 
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -73,7 +73,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                             // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                             // If Is_local, set Is_causal to false
-                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                             // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;

+ 8 - 4
tests/test_flash_attn.py

@@ -895,12 +895,14 @@ def test_flash_attn_output(
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
@@ -1162,12 +1164,14 @@ def test_flash_attn_varlen_output(
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)