@@ -566,10 +566,12 @@ def get_dropout_fraction(

 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@@ -578,16 +580,16 @@ def get_dropout_fraction(
 # @pytest.mark.parametrize("d", [64])
 # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
 @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize("seqlen", [97])
+# @pytest.mark.parametrize("seqlen", [512])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 4
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
     qkv = torch.randn(
@@ -604,6 +606,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
         causal=causal,
         window_size=window_size,
         alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
         return_attn_probs=True,
     )
     if dropout_p > 0.0:
@@ -712,6 +715,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)

 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -725,7 +730,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
 # @pytest.mark.parametrize('seqlen', [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
@@ -760,6 +765,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
         causal=causal,
         window_size=window_size,
         alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
         return_attn_probs=True,
     )
     out = output_pad_fn(out_unpad)
@@ -859,6 +865,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -890,7 +898,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -900,7 +908,7 @@ def test_flash_attn_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 4
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
@@ -931,6 +939,7 @@ def test_flash_attn_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     else:
@@ -942,6 +951,7 @@ def test_flash_attn_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     if dropout_p > 0.0:
@@ -1114,6 +1124,8 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize('mha_type', ["mqa"])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -1143,7 +1155,7 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -1153,7 +1165,7 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 4
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
@@ -1207,6 +1219,7 @@ def test_flash_attn_varlen_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     else:
@@ -1237,6 +1250,7 @@ def test_flash_attn_varlen_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     out = output_pad_fn(out_unpad)
@@ -1675,6 +1689,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp

 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -1704,7 +1720,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
@@ -1729,6 +1745,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
         causal=causal,
         window_size=window_size,
         alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
         return_attn_probs=True,
     )
     out_ref, attn_ref = attention_ref(
@@ -2224,3 +2241,152 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
     assert not q.grad.isnan().any()
     assert not k.grad.isnan().any()
     assert not v.grad.isnan().any()
+
+
+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+# @pytest.mark.parametrize("swap_sq_sk", [False])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
+
+    g = torch.randn_like(out)
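+    # With deterministic=True, repeated backward passes through the same graph must produce
+    # bitwise-identical gradients, so each run is compared against the first with torch.equal.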
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+        for _ in range(50):
+            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
+
+
+
+
+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+# @pytest.mark.parametrize("swap_sq_sk", [True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
+def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    out = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        deterministic=True,
+    )
+
+    g = torch.randn_like(out)
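+    # Same check as in test_flash_attn_deterministic, but on the varlen path: the backward
+    # must be bitwise reproducible, so compare each of the 50 repeated runs against the first.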
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        for _ in range(50):
+            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
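
For context, a minimal standalone sketch of the flag these tests exercise (not part of the patch). It assumes flash_attn_func is imported from the flash_attn package and uses the (batch, seqlen, nheads, headdim) tensor layout from the tests above; with deterministic=True, repeating the backward pass yields bitwise-identical gradients, which is what test_flash_attn_deterministic and test_flash_attn_varlen_deterministic assert.

import torch
from flash_attn import flash_attn_func

# Toy shapes for illustration: batch=2, seqlen=1024, nheads=9, headdim=64.
q = torch.randn(2, 1024, 9, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 1024, 9, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 1024, 9, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# dropout_p is passed positionally (0.0 here), matching the calls in the tests above.
out = flash_attn_func(q, k, v, 0.0, causal=True, deterministic=True)
out.backward(torch.randn_like(out))  # gradients are reproducible from run to run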