--- flash_ori.py	2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py	2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
-        if (
-            flash_ver_parsed != (2, 3, 6)
-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-        ):
-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+    # if (
+    #     flash_ver_parsed != (2, 3, 6)
+    #     and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+    # ):
+    #     raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 
     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-
-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_left, "
-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
 
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, "
-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_left, "
+    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, "
+    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)
 
     def _flash_fwd(
         query,
@@ -111,8 +111,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,  # window_size_left
-                window_right,  # window_size_right
+                # window_left,  # window_size_left
+                # window_right,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -134,15 +134,15 @@
                 out,
                 cu_seq_lens_q,
                 cu_seq_lens_k,
-                seqused_k,
+                # seqused_k,
                 max_seq_len_q,
                 max_seq_len_k,
                 p,
                 softmax_scale,
                 False,
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 return_softmax,
                 None,
             )
@@ -184,8 +184,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 None,
                 rng_state,
             )
@@ -208,15 +208,15 @@
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 None,
                 rng_state,
             )
         return dq, dk, dv
 
-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 except ImportError:
     pass
@@ -400,7 +400,7 @@
     implementation.
     """
 
-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd  # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
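With the patch applied, xformers' flash-attention forward op is wired directly to the Python wrapper (`OPERATOR = _flash_fwd`) instead of going through the `torch.library` dispatcher registration that the patch comments out. The snippet below is a minimal smoke test of that patched path, not part of the patch itself; it assumes xformers and flash-attn are installed, a CUDA device with compute capability 8.0+ is available, and that `FwOp`/`BwOp` live under `xformers.ops.fmha.flash` as in the xformers tree being patched.

# Minimal sketch exercising the patched flash path (assumption: xformers 0.0.23-style
# layout with FwOp/BwOp in xformers.ops.fmha.flash; CUDA device with SM >= 80).
import torch
import xformers.ops as xops
from xformers.ops.fmha import flash  # the module modified by the patch above

# Shapes are (batch, seq_len, heads, head_dim); FwOp requires half/bfloat16 on CUDA.
q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Forcing op=(flash.FwOp, flash.BwOp) makes the call go through the patched
# OPERATOR = _flash_fwd rather than falling back to another attention backend.
out = xops.memory_efficient_attention(q, k, v, op=(flash.FwOp, flash.BwOp))
print(out.shape)  # expected: torch.Size([1, 1024, 8, 64])

If this runs without hitting the removed `torch.library` registrations, the patched `flash.py` is being picked up correctly.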