@@ -1,6 +1,6 @@
---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
-+++ flash.py 2023-11-28 16:14:25.206128903 +0000
-@@ -31,39 +31,39 @@
+--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
++++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
+@@ -36,44 +36,44 @@

 FLASH_VERSION = "0.0.0"
 try:
@@ -15,9 +15,12 @@
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
-- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-- if flash_ver_parsed < (2, 3):
-- raise ImportError("Requires 2.3 for sliding window support")
+- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+- if (
+- flash_ver_parsed != (2, 3, 6)
+- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+- ):
+- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
@@ -29,35 +32,41 @@
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
-+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-+ # if flash_ver_parsed < (2, 3):
-+ # raise ImportError("Requires 2.3 for sliding window support")
++ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
++ # if (
++ # flash_ver_parsed != (2, 3, 6)
++ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
++ # ):
++ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")

 # create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
-+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
-
+-
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
-- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
-- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+- "bool is_causal, int window_left, "
+- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
--
++ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
+
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
-- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+- "float p, float softmax_scale, bool is_causal, "
+- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
-+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
++ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
-+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
++ # "bool is_causal, int window_left, "
++ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
@@ -65,52 +74,61 @@
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
-+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
++ # "float p, float softmax_scale, bool is_causal, "
++ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)

 def _flash_fwd(
 query,
-@@ -98,8 +98,8 @@
+@@ -111,8 +111,8 @@
 p,
 softmax_scale,
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left, # window_size_left
+- window_right, # window_size_right
++ # window_left, # window_size_left
++ # window_right, # window_size_right
 return_softmax,
 None, # rng
 )
-@@ -127,8 +127,8 @@
+@@ -134,15 +134,15 @@
+ out,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+- seqused_k,
++ # seqused_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
 softmax_scale,
 False,
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left,
+- window_right,
++ # window_left,
++ # window_right,
 return_softmax,
 None,
 )
-@@ -169,8 +169,8 @@
+@@ -184,8 +184,8 @@
 p,
 softmax_scale,
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left,
+- window_right,
++ # window_left,
++ # window_right,
 None,
 rng_state,
 )
-@@ -193,15 +193,15 @@
+@@ -208,15 +208,15 @@
 softmax_scale,
 False, # zero_tensors
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left,
+- window_right,
++ # window_left,
++ # window_right,
 None,
 rng_state,
 )
@@ -123,7 +141,7 @@
 except ImportError:
 pass

-@@ -348,7 +348,7 @@
+@@ -400,7 +400,7 @@
 implementation.
 """