|
@@ -26,6 +26,31 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
|
|
|
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
|
|
|
|
|
|
|
|
|
+def attn_bias_from_alibi_slopes(
|
|
|
+ slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
|
|
|
+):
|
|
|
+ batch, nheads = slopes.shape
|
|
|
+ device = slopes.device
|
|
|
+ slopes = rearrange(slopes, "b h -> b h 1 1")
|
|
|
+ if causal:
|
|
|
+ return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
|
|
|
+ else:
|
|
|
+ row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
|
|
+ col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
|
|
+ sk = (
|
|
|
+ seqlen_k
|
|
|
+ if key_padding_mask is None
|
|
|
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
|
|
|
+ )
|
|
|
+ sq = (
|
|
|
+ seqlen_q
|
|
|
+ if query_padding_mask is None
|
|
|
+ else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
|
|
|
+ )
|
|
|
+ relative_pos = torch.abs(row_idx + sk - sq - col_idx)
|
|
|
+ return -slopes * relative_pos.to(dtype=slopes.dtype)
|
|
|
+
|
|
|
+
|
|
|
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
|
|
|
assert mode in ["full", "random", "third"]
|
|
|
if mode == "full":
|
|
@@ -186,6 +211,7 @@ def attention_ref(
|
|
|
v,
|
|
|
query_padding_mask=None,
|
|
|
key_padding_mask=None,
|
|
|
+ attn_bias=None,
|
|
|
dropout_p=0.0,
|
|
|
dropout_mask=None,
|
|
|
causal=False,
|
|
@@ -200,6 +226,7 @@ def attention_ref(
|
|
|
v: (batch_size, seqlen_k, nheads_k, head_dim)
|
|
|
query_padding_mask: (batch_size, seqlen_q)
|
|
|
key_padding_mask: (batch_size, seqlen_k)
|
|
|
+ attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
|
|
|
dropout_p: float
|
|
|
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
|
|
|
causal: whether to apply causal masking
|
|
@@ -238,7 +265,9 @@ def attention_ref(
|
|
|
q.device,
|
|
|
)
|
|
|
scores.masked_fill_(local_mask, float("-inf"))
|
|
|
- attention = torch.softmax(scores, dim=-1)
|
|
|
+ if attn_bias is not None:
|
|
|
+ scores = scores + attn_bias
|
|
|
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
|
|
# Some rows might be completely masked out so we fill them with zero instead of NaN
|
|
|
if window_size[0] >= 0 or window_size[1] >= 0:
|
|
|
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
|
|
@@ -264,6 +293,7 @@ def attention_kvpacked_ref(
|
|
|
kv,
|
|
|
query_padding_mask=None,
|
|
|
key_padding_mask=None,
|
|
|
+ attn_bias=None,
|
|
|
dropout_p=0.0,
|
|
|
dropout_mask=None,
|
|
|
causal=False,
|
|
@@ -277,6 +307,7 @@ def attention_kvpacked_ref(
|
|
|
kv[:, :, 1],
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
upcast=upcast,
|
|
@@ -289,6 +320,7 @@ def attention_kvpacked_ref(
|
|
|
def attention_qkvpacked_ref(
|
|
|
qkv,
|
|
|
key_padding_mask=None,
|
|
|
+ attn_bias=None,
|
|
|
dropout_p=0.0,
|
|
|
dropout_mask=None,
|
|
|
causal=False,
|
|
@@ -302,6 +334,7 @@ def attention_qkvpacked_ref(
|
|
|
qkv[:, :, 2],
|
|
|
key_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
upcast=upcast,
|
|
@@ -436,6 +469,7 @@ def normalize_flash_attn_S(
|
|
|
v,
|
|
|
query_padding_mask=None,
|
|
|
key_padding_mask=None,
|
|
|
+ attn_bias=None,
|
|
|
is_dropout=False,
|
|
|
causal=False,
|
|
|
window_size=(-1, -1), # -1 means infinite window size
|
|
@@ -445,6 +479,7 @@ def normalize_flash_attn_S(
|
|
|
q: (batch_size, seqlen_q, nheads, head_dim)
|
|
|
k, v: (batch_size, seqlen_k, nheads, head_dim)
|
|
|
key_padding_mask: (batch_size, seqlen_q)
|
|
|
+ attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
|
|
|
Output:
|
|
|
softmax_lse: (batch_size, nheads, seqlen_q)
|
|
|
softmax_max: (batch_size, nheads, seqlen_q)
|
|
@@ -467,6 +502,8 @@ def normalize_flash_attn_S(
|
|
|
q.device,
|
|
|
)
|
|
|
scores.masked_fill_(local_mask, float("-inf"))
|
|
|
+ if attn_bias is not None:
|
|
|
+ scores = scores + attn_bias.to(dtype=scores.dtype)
|
|
|
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
|
|
|
scores_block = scores.split(block_size_n, dim=-1)
|
|
|
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
|
|
@@ -529,6 +566,8 @@ def get_dropout_fraction(
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
|
|
# @pytest.mark.parametrize("dtype", [torch.float16])
|
|
|
+@pytest.mark.parametrize("alibi", [False, True])
|
|
|
+# @pytest.mark.parametrize("alibi", [True])
|
|
|
@pytest.mark.parametrize("local", [False, True])
|
|
|
# @pytest.mark.parametrize("local", [True])
|
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@@ -538,24 +577,34 @@ def get_dropout_fraction(
|
|
|
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
|
|
|
# @pytest.mark.parametrize("d", [64])
|
|
|
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
|
|
|
-@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
|
|
|
-# @pytest.mark.parametrize("seqlen", [128])
|
|
|
+@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
|
|
|
+# @pytest.mark.parametrize("seqlen", [97])
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
|
|
# @pytest.mark.parametrize("dropout_p", [0.0])
|
|
|
-def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
|
|
|
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
|
|
|
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
|
|
|
pytest.skip() # Reference implementation OOM
|
|
|
device = "cuda"
|
|
|
# set seed
|
|
|
torch.random.manual_seed(0)
|
|
|
- batch_size = 13
|
|
|
+ batch_size = 8
|
|
|
nheads = 9
|
|
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
|
|
|
qkv = torch.randn(
|
|
|
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
|
|
|
)
|
|
|
+ if alibi:
|
|
|
+ alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
|
|
+ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
|
|
|
+ else:
|
|
|
+ alibi_slopes, attn_bias = None, None
|
|
|
out, lse, S_dmask = flash_attn_qkvpacked_func(
|
|
|
- qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
|
|
|
+ qkv,
|
|
|
+ dropout_p,
|
|
|
+ causal=causal,
|
|
|
+ window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
+ return_attn_probs=True,
|
|
|
)
|
|
|
if dropout_p > 0.0:
|
|
|
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
|
@@ -578,6 +627,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
|
|
|
qkv[:, :, 2],
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p > 0.0,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
@@ -590,11 +640,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
|
|
|
dropout_mask = None
|
|
|
|
|
|
out_ref, attn_ref = attention_qkvpacked_ref(
|
|
|
- qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size
|
|
|
+ qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
|
|
|
)
|
|
|
out_pt, attn_pt = attention_qkvpacked_ref(
|
|
|
qkv,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -651,7 +702,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
|
|
|
|
|
|
if dropout_p > 0.0:
|
|
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
|
|
|
- assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
+ # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
|
|
|
+ if not alibi:
|
|
|
+ assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
|
|
|
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
|
|
|
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
|
|
@@ -659,18 +712,20 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
|
|
# @pytest.mark.parametrize('dtype', [torch.float16])
|
|
|
+@pytest.mark.parametrize("alibi", [False, True])
|
|
|
+# @pytest.mark.parametrize("alibi", [True])
|
|
|
@pytest.mark.parametrize("local", [False, True])
|
|
|
# @pytest.mark.parametrize("local", [True])
|
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
|
# @pytest.mark.parametrize('causal', [False])
|
|
|
-@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
|
|
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize('d', [64])
|
|
|
-@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
|
|
|
+@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
|
|
|
# @pytest.mark.parametrize('seqlen', [128])
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
|
|
# @pytest.mark.parametrize('dropout_p', [0.0])
|
|
|
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
|
|
|
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
|
|
|
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
|
|
|
pytest.skip() # Reference implementation OOM
|
|
|
device = "cuda"
|
|
@@ -685,6 +740,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
|
|
|
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
|
|
|
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
|
|
|
+ if alibi:
|
|
|
+ alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
|
|
+ attn_bias = attn_bias_from_alibi_slopes(
|
|
|
+ alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ alibi_slopes, attn_bias = None, None
|
|
|
|
|
|
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
|
|
|
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
|
|
@@ -697,6 +759,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
dropout_p,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
return_attn_probs=True,
|
|
|
)
|
|
|
out = output_pad_fn(out_unpad)
|
|
@@ -721,6 +784,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
qkv[:, :, 2],
|
|
|
key_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p > 0.0,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
@@ -733,11 +797,18 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
dropout_mask = None
|
|
|
|
|
|
out_ref, attn_ref = attention_qkvpacked_ref(
|
|
|
- qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size
|
|
|
+ qkv,
|
|
|
+ key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
+ dropout_p,
|
|
|
+ dropout_mask,
|
|
|
+ causal=causal,
|
|
|
+ window_size=window_size,
|
|
|
)
|
|
|
out_pt, attn_pt = attention_qkvpacked_ref(
|
|
|
qkv,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -774,7 +845,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
|
|
|
if dropout_p > 0.0:
|
|
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
|
|
|
- assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
+ # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
|
|
|
+ if not alibi:
|
|
|
+ assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
|
|
|
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
|
|
|
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
|
|
@@ -786,11 +859,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
|
|
# @pytest.mark.parametrize("mha_type", ["mha"])
|
|
|
+@pytest.mark.parametrize("alibi", [False, True])
|
|
|
+# @pytest.mark.parametrize("alibi", [True])
|
|
|
@pytest.mark.parametrize("local", [False, True])
|
|
|
# @pytest.mark.parametrize("local", [True])
|
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
|
# @pytest.mark.parametrize("causal", [True])
|
|
|
-@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
|
|
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
|
|
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
|
@@ -815,7 +890,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
|
|
# @pytest.mark.parametrize("dropout_p", [0.17])
|
|
|
def test_flash_attn_output(
|
|
|
- seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
|
|
|
+ seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
|
|
|
):
|
|
|
if (
|
|
|
max(seqlen_q, seqlen_k) >= 2048
|
|
@@ -825,7 +900,7 @@ def test_flash_attn_output(
|
|
|
device = "cuda"
|
|
|
# set seed
|
|
|
torch.random.manual_seed(0)
|
|
|
- batch_size = 13
|
|
|
+ batch_size = 8
|
|
|
nheads = 9
|
|
|
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
|
|
assert nheads % nheads_k == 0
|
|
@@ -842,14 +917,32 @@ def test_flash_attn_output(
|
|
|
v = torch.randn(
|
|
|
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
|
|
|
)
|
|
|
+ if alibi:
|
|
|
+ alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
|
|
+ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
|
|
|
+ else:
|
|
|
+ alibi_slopes, attn_bias = None, None
|
|
|
|
|
|
if kvpacked:
|
|
|
out, lse, S_dmask = flash_attn_kvpacked_func(
|
|
|
- q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
|
|
|
+ q,
|
|
|
+ kv,
|
|
|
+ dropout_p,
|
|
|
+ causal=causal,
|
|
|
+ window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
+ return_attn_probs=True,
|
|
|
)
|
|
|
else:
|
|
|
out, lse, S_dmask = flash_attn_func(
|
|
|
- q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
|
|
|
+ q,
|
|
|
+ k,
|
|
|
+ v,
|
|
|
+ dropout_p,
|
|
|
+ causal=causal,
|
|
|
+ window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
+ return_attn_probs=True,
|
|
|
)
|
|
|
if dropout_p > 0.0:
|
|
|
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
|
@@ -878,6 +971,7 @@ def test_flash_attn_output(
|
|
|
v_rep,
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p > 0.0,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
@@ -895,6 +989,7 @@ def test_flash_attn_output(
|
|
|
kv,
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -905,6 +1000,7 @@ def test_flash_attn_output(
|
|
|
kv,
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -919,6 +1015,7 @@ def test_flash_attn_output(
|
|
|
v,
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -930,6 +1027,7 @@ def test_flash_attn_output(
|
|
|
v,
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -1000,7 +1098,9 @@ def test_flash_attn_output(
|
|
|
|
|
|
if dropout_p > 0.0:
|
|
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
|
|
|
- assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
+ # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
|
|
|
+ if not alibi:
|
|
|
+ assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
|
|
|
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
|
|
|
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
|
@@ -1014,11 +1114,13 @@ def test_flash_attn_output(
|
|
|
# @pytest.mark.parametrize('dtype', [torch.float16])
|
|
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
|
|
# @pytest.mark.parametrize('mha_type', ["mqa"])
|
|
|
+@pytest.mark.parametrize("alibi", [False, True])
|
|
|
+# @pytest.mark.parametrize("alibi", [True])
|
|
|
@pytest.mark.parametrize("local", [False, True])
|
|
|
# @pytest.mark.parametrize("local", [True])
|
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
|
# @pytest.mark.parametrize('causal', [True])
|
|
|
-@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
|
|
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize('d', [64])
|
|
|
@pytest.mark.parametrize(
|
|
@@ -1041,7 +1143,7 @@ def test_flash_attn_output(
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
|
|
# @pytest.mark.parametrize('dropout_p', [0.0])
|
|
|
def test_flash_attn_varlen_output(
|
|
|
- seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
|
|
|
+ seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
|
|
|
):
|
|
|
if (
|
|
|
max(seqlen_q, seqlen_k) >= 2048
|
|
@@ -1051,7 +1153,7 @@ def test_flash_attn_varlen_output(
|
|
|
device = "cuda"
|
|
|
# set seed
|
|
|
torch.random.manual_seed(0)
|
|
|
- batch_size = 13
|
|
|
+ batch_size = 8
|
|
|
nheads = 9
|
|
|
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
|
|
assert nheads % nheads_k == 0
|
|
@@ -1072,6 +1174,13 @@ def test_flash_attn_varlen_output(
|
|
|
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
|
|
|
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
|
|
|
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
|
|
|
+ if alibi:
|
|
|
+ alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
|
|
+ attn_bias = attn_bias_from_alibi_slopes(
|
|
|
+ alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ alibi_slopes, attn_bias = None, None
|
|
|
|
|
|
if kvpacked:
|
|
|
(
|
|
@@ -1095,9 +1204,10 @@ def test_flash_attn_varlen_output(
|
|
|
max_seqlen_q,
|
|
|
max_seqlen_k,
|
|
|
dropout_p,
|
|
|
- return_attn_probs=True,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
+ return_attn_probs=True,
|
|
|
)
|
|
|
else:
|
|
|
(
|
|
@@ -1124,9 +1234,10 @@ def test_flash_attn_varlen_output(
|
|
|
max_seqlen_q,
|
|
|
max_seqlen_k,
|
|
|
dropout_p,
|
|
|
- return_attn_probs=True,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
+ return_attn_probs=True,
|
|
|
)
|
|
|
out = output_pad_fn(out_unpad)
|
|
|
if dropout_p > 0.0:
|
|
@@ -1156,6 +1267,7 @@ def test_flash_attn_varlen_output(
|
|
|
v_rep,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p > 0.0,
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
@@ -1177,6 +1289,7 @@ def test_flash_attn_varlen_output(
|
|
|
kv,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -1187,6 +1300,7 @@ def test_flash_attn_varlen_output(
|
|
|
kv,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -1201,6 +1315,7 @@ def test_flash_attn_varlen_output(
|
|
|
v,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -1212,6 +1327,7 @@ def test_flash_attn_varlen_output(
|
|
|
v,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
dropout_p,
|
|
|
dropout_mask,
|
|
|
causal=causal,
|
|
@@ -1284,12 +1400,14 @@ def test_flash_attn_varlen_output(
|
|
|
|
|
|
if dropout_p > 0.0:
|
|
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
|
|
|
- assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
+ # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
|
|
|
+ if not alibi:
|
|
|
+ assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
|
|
|
|
|
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
|
|
|
- assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
|
|
- assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
|
|
- assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
|
|
+ assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
|
|
|
+ assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
|
|
|
+ assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
|
@@ -1332,7 +1450,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
|
|
|
causal = True
|
|
|
# set seed
|
|
|
torch.random.manual_seed(0)
|
|
|
- batch_size = 13
|
|
|
+ batch_size = 8
|
|
|
nheads = 9
|
|
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
|
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
|
@@ -1340,7 +1458,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
|
|
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
|
|
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
|
|
|
out_ref, attn_ref = attention_ref(
|
|
|
- q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
|
|
|
+ q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
|
|
|
)
|
|
|
out_pt, attn_pt = attention_ref(
|
|
|
q,
|
|
@@ -1348,6 +1466,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
|
|
|
v,
|
|
|
None,
|
|
|
None,
|
|
|
+ None,
|
|
|
0.0,
|
|
|
None,
|
|
|
causal=causal,
|
|
@@ -1442,7 +1561,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
|
|
causal = True
|
|
|
# set seed
|
|
|
torch.random.manual_seed(0)
|
|
|
- batch_size = 13
|
|
|
+ batch_size = 8
|
|
|
nheads = 9
|
|
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
|
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
|
@@ -1484,6 +1603,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
|
|
v,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ None,
|
|
|
0.0,
|
|
|
None,
|
|
|
causal=causal,
|
|
@@ -1495,6 +1615,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
|
|
v,
|
|
|
query_padding_mask,
|
|
|
key_padding_mask,
|
|
|
+ None,
|
|
|
0.0,
|
|
|
None,
|
|
|
causal=causal,
|
|
@@ -1554,8 +1675,10 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
|
|
# @pytest.mark.parametrize("dtype", [torch.float16])
|
|
|
+@pytest.mark.parametrize("alibi", [False, True])
|
|
|
+# @pytest.mark.parametrize("alibi", [True])
|
|
|
@pytest.mark.parametrize("local", [False, True])
|
|
|
-# @pytest.mark.parametrize("local", [True])
|
|
|
+# @pytest.mark.parametrize("local", [False])
|
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
|
# @pytest.mark.parametrize("causal", [True])
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
|
@@ -1581,7 +1704,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
|
|
],
|
|
|
)
|
|
|
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
|
|
|
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
|
|
|
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
|
|
|
if swap_sq_sk:
|
|
|
seqlen_q, seqlen_k = seqlen_k, seqlen_q
|
|
|
device = "cuda"
|
|
@@ -1593,11 +1716,23 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
|
|
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
|
|
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
|
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
|
|
+ if alibi:
|
|
|
+ alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
|
|
+ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
|
|
|
+ else:
|
|
|
+ alibi_slopes, attn_bias = None, None
|
|
|
out, lse, _ = flash_attn_func(
|
|
|
- q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True
|
|
|
+ q,
|
|
|
+ k,
|
|
|
+ v,
|
|
|
+ 0.0,
|
|
|
+ causal=causal,
|
|
|
+ window_size=window_size,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
+ return_attn_probs=True,
|
|
|
)
|
|
|
out_ref, attn_ref = attention_ref(
|
|
|
- q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
|
|
|
+ q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
|
|
|
)
|
|
|
out_pt, attn_pt = attention_ref(
|
|
|
q,
|
|
@@ -1605,6 +1740,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
|
|
|
v,
|
|
|
None,
|
|
|
None,
|
|
|
+ attn_bias,
|
|
|
0.0,
|
|
|
None,
|
|
|
causal=causal,
|
|
@@ -1653,24 +1789,27 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
|
|
|
# of a Pytorch implementation.
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
|
|
|
|
|
|
+ mult = 2 if not alibi else 8
|
|
|
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
|
|
|
- assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
|
|
|
- assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
|
|
|
- assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
|
|
|
+ assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
|
|
|
+ assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
|
|
|
+ assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
|
|
-# @pytest.mark.parametrize("dtype", [torch.float16])
|
|
|
+# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
|
|
+@pytest.mark.parametrize("dtype", [torch.float16])
|
|
|
@pytest.mark.parametrize("num_splits", [1, 0])
|
|
|
-# @pytest.mark.parametrize("num_splits", [0])
|
|
|
+# @pytest.mark.parametrize("num_splits", [1])
|
|
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
|
|
# @pytest.mark.parametrize("mha_type", ["mha"])
|
|
|
@pytest.mark.parametrize("new_kv", [False, True])
|
|
|
-# @pytest.mark.parametrize("new_kv", [True])
|
|
|
+# @pytest.mark.parametrize("new_kv", [False])
|
|
|
+@pytest.mark.parametrize("alibi", [False, True])
|
|
|
+# @pytest.mark.parametrize("alibi", [True])
|
|
|
@pytest.mark.parametrize("local", [False, True])
|
|
|
# @pytest.mark.parametrize("local", [False])
|
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
|
-# @pytest.mark.parametrize("causal", [True])
|
|
|
+# @pytest.mark.parametrize("causal", [False])
|
|
|
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
|
|
|
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
|
|
|
@pytest.mark.parametrize("rotary_interleaved", [False, True])
|
|
@@ -1678,7 +1817,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
|
|
|
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
|
|
|
# @pytest.mark.parametrize("rotary_fraction", [0.0])
|
|
|
@pytest.mark.parametrize("has_batch_idx", [False, True])
|
|
|
-# @pytest.mark.parametrize("has_batch_idx", [True])
|
|
|
+# @pytest.mark.parametrize("has_batch_idx", [False])
|
|
|
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
|
|
|
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
|
@@ -1711,6 +1850,7 @@ def test_flash_attn_kvcache(
|
|
|
seqlen_new_eq_seqlen_q,
|
|
|
causal,
|
|
|
local,
|
|
|
+ alibi,
|
|
|
new_kv,
|
|
|
mha_type,
|
|
|
num_splits,
|
|
@@ -1750,10 +1890,22 @@ def test_flash_attn_kvcache(
|
|
|
dtype=torch.int32,
|
|
|
device=device,
|
|
|
)
|
|
|
+ arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
|
|
|
+ cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
|
|
|
+ key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
|
|
|
if has_batch_idx:
|
|
|
- cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
|
|
|
+ cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
|
|
|
+ :batch_size
|
|
|
+ ]
|
|
|
else:
|
|
|
cache_batch_idx = None
|
|
|
+ if alibi:
|
|
|
+ alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
|
|
+ attn_bias = attn_bias_from_alibi_slopes(
|
|
|
+ alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ alibi_slopes, attn_bias = None, None
|
|
|
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
|
|
|
if rotary_dim > 0:
|
|
|
angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
|
|
@@ -1785,8 +1937,6 @@ def test_flash_attn_kvcache(
|
|
|
# k_cache[:, 64:] = -1
|
|
|
k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
|
|
|
v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
|
|
|
- arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
|
|
|
- cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
|
|
|
if new_kv:
|
|
|
update_mask = torch.logical_and(
|
|
|
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
|
|
@@ -1808,6 +1958,7 @@ def test_flash_attn_kvcache(
|
|
|
causal=causal,
|
|
|
window_size=window_size,
|
|
|
rotary_interleaved=rotary_interleaved,
|
|
|
+ alibi_slopes=alibi_slopes,
|
|
|
num_splits=num_splits,
|
|
|
)
|
|
|
# out = flash_attn_with_kvcache(
|
|
@@ -1820,13 +1971,13 @@ def test_flash_attn_kvcache(
|
|
|
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
|
|
|
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
|
|
|
# probs = torch.softmax(qk, dim=-1)
|
|
|
- key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
|
|
|
out_ref, _ = attention_ref(
|
|
|
q_ro,
|
|
|
k_cache_rep,
|
|
|
v_cache_rep,
|
|
|
None,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
0.0,
|
|
|
None,
|
|
|
causal=causal,
|
|
@@ -1838,6 +1989,7 @@ def test_flash_attn_kvcache(
|
|
|
v_cache_rep,
|
|
|
None,
|
|
|
key_padding_mask,
|
|
|
+ attn_bias,
|
|
|
0.0,
|
|
|
None,
|
|
|
causal=causal,
|
|
@@ -1857,7 +2009,8 @@ def test_flash_attn_kvcache(
|
|
|
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
|
|
|
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
|
|
|
assert torch.equal(v_cache_select, v_cache_ref)
|
|
|
- assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
|
|
|
+ mult = 3 if not alibi else 5
|
|
|
+ assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
|
|
|
|
|
|
|
|
|
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|