--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py	2023-11-29 03:17:03.930103539 +0000
+++ flash.py	2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@
 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-        if flash_ver_parsed < (2, 3):
-            raise ImportError("Requires 2.3 for sliding window support")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+    # if flash_ver_parsed < (2, 3):
+    #     raise ImportError("Requires 2.3 for sliding window support")
 
     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
 
-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
-
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)
 
     def _flash_fwd(
         query,
@@ -98,8 +98,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                # window_size - 1,  # window_size_left
+                # -1,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -127,8 +127,8 @@
                 softmax_scale,
                 False,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                # window_size - 1,  # window_size_left
+                # -1,  # window_size_right
                 return_softmax,
                 None,
             )
@@ -169,8 +169,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                # window_size - 1,  # window_size_left
+                # -1,  # window_size_right
                 None,
                 rng_state,
             )
@@ -193,15 +193,15 @@
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                # window_size - 1,  # window_size_left
+                # -1,  # window_size_right
                 None,
                 rng_state,
             )
             return dq, dk, dv
 
-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 except ImportError:
     pass
@@ -348,7 +348,7 @@
     implementation.
     """
 
-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd  # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
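In short, the patch bypasses xformers' own `_C_flashattention` extension and its PyTorch-dispatcher registration, importing `flash_attn_cuda` from an installed flash-attn package instead and wiring `FwOp.OPERATOR` straight to the Python `_flash_fwd` function. Because the `window_size_left`/`window_size_right` arguments are commented out of the kernel calls, sliding-window (local) attention is not functional under this patch; only the plain full/causal paths are restored against the older flash-attn call signature. A minimal, hypothetical smoke test follows (not part of the patch); it assumes flash-attn is installed and a CUDA device with compute capability >= 8.0, and exercises the forward pass only, since this excerpt rewires only `FwOp.OPERATOR`:

# Hypothetical smoke test for the patched module -- not part of the patch.
import torch
import xformers.ops as xops
from xformers.ops.fmha import flash

print(flash.FLASH_VERSION)  # should now report flash_attn.__version__

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash fwd/bwd ops so the patched _flash_fwd path is exercised
# directly, instead of going through the (now disabled) dispatcher library.
with torch.no_grad():
    out = xops.memory_efficient_attention(q, k, v, op=(flash.FwOp, flash.BwOp))
print(out.shape)  # torch.Size([1, 128, 8, 64])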