@@ -182,9 +182,14 @@ def construct_local_mask(
     query_padding_mask=None,
     key_padding_mask=None,
     device=None,
+    key_leftpad=None,
 ):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
     sk = (
         seqlen_k
         if key_padding_mask is None
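What the leftpad remapping above does: with a left-padded KV cache, the first key_leftpad entries of each sequence are padding. The added lines re-base col_idx so the first valid key gets index 0, and send padded columns to 2**32, large enough that every later window/causal comparison rejects them. A minimal sketch of the effect, using illustrative values (not part of the patch):

    import torch
    from einops import rearrange, repeat

    seqlen_k = 6
    key_leftpad = torch.tensor([2])  # batch of 1; the first 2 keys are padding
    col_idx = torch.arange(seqlen_k, dtype=torch.long)
    leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
    col_idx = repeat(col_idx, "s -> b 1 1 s", b=leftpad.shape[0])
    col_idx = torch.where(col_idx >= leftpad, col_idx - leftpad, 2**32)
    # col_idx[0, 0, 0] is now [2**32, 2**32, 0, 1, 2, 3]: padded columns can
    # never fall inside a window, and valid keys are re-indexed from 0.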
@@ -219,6 +224,7 @@ def attention_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     """
     Arguments:
@@ -268,6 +274,7 @@ def attention_ref(
             query_padding_mask,
             key_padding_mask,
             q.device,
+            key_leftpad=key_leftpad,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
@@ -306,6 +313,7 @@ def attention_kvpacked_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     return attention_ref(
         q,
@@ -321,6 +329,7 @@ def attention_kvpacked_ref(
         window_size=window_size,
         softcap=softcap,
         reorder_ops=reorder_ops,
+        key_leftpad=key_leftpad,
     )
 
 
@@ -1868,9 +1877,11 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("paged_kv_block_size", [None])
+@pytest.mark.parametrize("has_leftpad", [False, True])
+# @pytest.mark.parametrize("has_leftpad", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1898,6 +1909,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    has_leftpad,
     paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
@@ -1916,6 +1928,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if has_batch_idx and paged_kv_block_size is not None:
         pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1961,9 +1975,19 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
     if has_batch_idx:
         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
             :batch_size
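How the test constructs the reference mask: each batch element draws cache_leftpad[i] uniformly from [0, cache_seqlens[i]) (zero when the cache entry is empty), and key_padding_mask is then restricted so only positions in [cache_leftpad[i], cache_seqlens[i]) count as valid keys (plus seqlen_new when appending). A small worked example with illustrative numbers (not from the patch):

    import torch

    seqlen_k, cache_seqlen, leftpad = 8, 5, 2
    arange = torch.arange(seqlen_k)
    mask = (arange < cache_seqlen) & (arange >= leftpad)
    # mask == [False, False, True, True, True, False, False, False]
    # i.e. only key positions 2, 3, 4 participate in attention.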
@@ -2038,6 +2062,7 @@ def test_flash_attn_kvcache(
         rotary_sin=sin,
         cache_seqlens=cache_seqlens,
         cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
         block_table=block_table,
         causal=causal,
         window_size=window_size,
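From the caller's side, cache_leftpad is one more optional kwarg to flash_attn_with_kvcache, alongside cache_seqlens and cache_batch_idx. A hedged usage sketch with made-up shapes (assumes a flash-attn build that includes this patch):

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, d, seqlen_k = 2, 8, 64, 128
    q = torch.randn(batch, 1, nheads, d, dtype=torch.float16, device="cuda")
    k_cache = torch.randn(batch, seqlen_k, nheads, d, dtype=torch.float16, device="cuda")
    v_cache = torch.randn_like(k_cache)
    cache_seqlens = torch.tensor([100, 80], dtype=torch.int32, device="cuda")
    cache_leftpad = torch.tensor([10, 0], dtype=torch.int32, device="cuda")  # keys before these offsets are ignored
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens,
        cache_leftpad=cache_leftpad,
        causal=True,
    )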
@@ -2066,6 +2091,7 @@ def test_flash_attn_kvcache(
         None,
         causal=causal,
         window_size=window_size,
+        key_leftpad=cache_leftpad,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -2080,6 +2106,7 @@ def test_flash_attn_kvcache(
         window_size=window_size,
         upcast=False,
         reorder_ops=True,
+        key_leftpad=cache_leftpad,
     )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")