--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
 
 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
-        if (
-            flash_ver_parsed != (2, 3, 6)
-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-        ):
-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+    #    if (
+    #        flash_ver_parsed != (2, 3, 6)
+    #        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+    #    ):
+    #        raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 
     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-
-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_left, "
-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
 
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, "
-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_left, "
+    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, "
+    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)
 
     def _flash_fwd(
         query,
@@ -111,8 +111,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,  # window_size_left
-                window_right,  # window_size_right
+                # window_left,  # window_size_left
+                # window_right,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -134,15 +134,15 @@
                 out,
                 cu_seq_lens_q,
                 cu_seq_lens_k,
-                seqused_k,
+                # seqused_k,
                 max_seq_len_q,
                 max_seq_len_k,
                 p,
                 softmax_scale,
                 False,
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 return_softmax,
                 None,
             )
@@ -184,8 +184,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 None,
                 rng_state,
             )
@@ -208,15 +208,15 @@
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 None,
                 rng_state,
             )
         return dq, dk, dv
 
-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 except ImportError:
     pass
 
@@ -400,7 +400,7 @@
        implementation.
    """
 
-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd  # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
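
After applying the patch, a quick smoke test can confirm that the rewired `flash_attn` import and the trimmed argument lists actually work. The snippet below is a minimal sketch, not part of the patch: it assumes xformers and flash-attn are installed and a GPU is visible, and it forces the FlashAttention forward/backward ops so that any remaining import or signature mismatch fails loudly instead of silently falling back to another backend.

```python
# Minimal post-patch smoke test (assumes xformers, flash-attn and a GPU are
# available; shapes and dtypes are arbitrary examples).
import torch
import xformers.ops as xops

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Forcing the flash ops surfaces problems with the patched flash.py
# immediately instead of dispatching to a different implementation.
out = xops.memory_efficient_attention(
    q, k, v, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
)
print(out.shape)  # expected: torch.Size([1, 128, 8, 64])

# The patch drops the hard pin on flash-attn 2.3.6, so also check which
# version the commented-out guard would have rejected.
import flash_attn
print("flash-attn", flash_attn.__version__)
```

With a plain dense input this exercises the non-varlen `fwd` path; passing a block-diagonal attention bias (e.g. built with `xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens`) would route through the patched `varlen_fwd` call instead.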
|