|
@@ -6,8 +6,7 @@ import torch.nn.functional as F
|
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
|
|
-# from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
|
|
|
-from test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
|
|
|
+from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
|
|
|
|
|
|
ABS_TOL = 5e-3
|
|
|
REL_TOL = 1e-1
|