@@ -895,12 +895,14 @@ def test_flash_attn_output(
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
@@ -1162,12 +1164,14 @@ def test_flash_attn_varlen_output(
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)