@@ -47,7 +47,7 @@ def generate_qkv(
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -62,8 +62,8 @@ def generate_qkv(
         )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, *rest = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _, *rest = unpad_input(v, key_padding_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
@@ -404,7 +404,7 @@ def test_flash_attn_output(
     # breakpoint()
 
 
-    if dtype != torch.float8_e4m3fn and not V_colmajor and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn and not V_colmajor:
         g = torch.randn_like(out)
         do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
         import flashattn_hopper_cuda
@@ -459,7 +459,7 @@ def test_flash_attn_output(
     multiple = 2
     assert (out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()
 
-    if dtype != torch.float8_e4m3fn and not V_colmajor and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn and not V_colmajor:
         multiple = 2 if softcap == 0.0 else 4
         assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item()
@@ -603,7 +603,7 @@ def test_flash_attn_varlen_output(
     # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
     # breakpoint()
 
-    if dtype != torch.float8_e4m3fn and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn:
         g_unpad = torch.randn_like(out_unpad)
         do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
         import flashattn_hopper_cuda
@@ -667,10 +667,8 @@ def test_flash_attn_varlen_output(
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
 
-    if dtype != torch.float8_e4m3fn and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn:
         multiple = 2 if softcap == 0.0 else 4
         assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item()
-
-
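# Illustrative sketch, not part of the patch: why the generate_qkv hunks swap a
# fixed five-element unpacking for `*rest`. The assumption here is that
# unpad_input may return either four values (hidden_states, indices, cu_seqlens,
# max_seqlen) or five (with an extra trailing element such as per-sequence used
# lengths), depending on the flash-attn version. A fixed `a, b, c, d, _ = ...`
# raises ValueError on a four-tuple, while starred unpacking tolerates both.
# The stand-in functions below are placeholders, not the library's API.

def unpad_input_four(x, mask):
    """Stand-in for a version returning four values."""
    return "x_unpad", "indices", "cu_seqlens", "max_seqlen"

def unpad_input_five(x, mask):
    """Stand-in for a version returning five values."""
    return "x_unpad", "indices", "cu_seqlens", "max_seqlen", "seqused"

for fn in (unpad_input_four, unpad_input_five):
    # `*rest` collects zero or more trailing items, so both arities unpack cleanly.
    x_unpad, indices, cu_seqlens, max_seqlen, *rest = fn(None, None)
    print(fn.__name__, "->", len(rest), "extra trailing value(s)")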
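# Illustrative sketch, not part of the patch: the shape of the tolerance check
# that the test hunks now also run when softcap != 0.0. Dropping the
# `and softcap == 0.0` gate keeps the backward checks active with softcap
# enabled, and `multiple = 2 if softcap == 0.0 else 4` widens the allowed error
# to 4x (instead of 2x) of the reference implementation's own numerical error.
# Tensor and parameter names below are placeholders, not the test's fixtures.
import torch

def within_tolerance(impl, ref, ref_pt, softcap=0.0):
    # Allow `multiple` times the error of the lower-precision reference (ref_pt)
    # measured against the higher-precision reference (ref).
    multiple = 2 if softcap == 0.0 else 4
    return (impl - ref).abs().max().item() <= multiple * (ref_pt - ref).abs().max().item()

# Toy example with synthetic tensors.
ref = torch.zeros(4)
assert within_tolerance(ref + 1e-4, ref, ref + 1e-3, softcap=0.0)
assert within_tolerance(ref + 2e-3, ref, ref + 1e-3, softcap=30.0)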