flashpy_xformers-0.0.22.post7.rocm.patch 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. --- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
  2. +++ flash.py 2023-11-28 16:14:25.206128903 +0000
  3. @@ -31,39 +31,39 @@
  4. FLASH_VERSION = "0.0.0"
  5. try:
  6. - try:
  7. - from ... import _C_flashattention # type: ignore[attr-defined]
  8. - from ..._cpp_lib import _build_metadata
  9. -
  10. - if _build_metadata is not None:
  11. - FLASH_VERSION = _build_metadata.flash_version
  12. - except ImportError:
  13. - import flash_attn
  14. - from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
  15. -
  16. - FLASH_VERSION = flash_attn.__version__
  17. - flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
  18. - if flash_ver_parsed < (2, 3):
  19. - raise ImportError("Requires 2.3 for sliding window support")
  20. + #try:
  21. + # from ... import _C_flashattention # type: ignore[attr-defined]
  22. + # from ..._cpp_lib import _build_metadata
  23. +
  24. + # if _build_metadata is not None:
  25. + # FLASH_VERSION = _build_metadata.flash_version
  26. + #except ImportError:
  27. + import flash_attn
  28. + from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
  29. +
  30. + FLASH_VERSION = flash_attn.__version__
  31. + # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
  32. + # if flash_ver_parsed < (2, 3):
  33. + # raise ImportError("Requires 2.3 for sliding window support")
  34. # create library so that flash-attn goes through the PyTorch Dispatcher
  35. - _flash_lib = torch.library.Library("xformers_flash", "DEF")
  36. + #_flash_lib = torch.library.Library("xformers_flash", "DEF")
  37. - _flash_lib.define(
  38. - "flash_fwd(Tensor query, Tensor key, Tensor value, "
  39. - "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
  40. - "int max_seqlen_q, int max_seqlen_k, "
  41. - "float p, float softmax_scale, "
  42. - "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
  43. - )
  44. -
  45. - _flash_lib.define(
  46. - "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
  47. - "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
  48. - "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
  49. - "int max_seqlen_q, int max_seqlen_k, "
  50. - "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
  51. - )
  52. + #_flash_lib.define(
  53. + # "flash_fwd(Tensor query, Tensor key, Tensor value, "
  54. + # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
  55. + # "int max_seqlen_q, int max_seqlen_k, "
  56. + # "float p, float softmax_scale, "
  57. + # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
  58. + #)
  59. +
  60. + #_flash_lib.define(
  61. + # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
  62. + # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
  63. + # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
  64. + # "int max_seqlen_q, int max_seqlen_k, "
  65. + # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
  66. + #)
  67. def _flash_fwd(
  68. query,
  69. @@ -98,8 +98,8 @@
  70. p,
  71. softmax_scale,
  72. is_causal,
  73. - window_size - 1, # window_size_left
  74. - -1, # window_size_right
  75. + # window_size - 1, # window_size_left
  76. + # -1, # window_size_right
  77. return_softmax,
  78. None, # rng
  79. )
  80. @@ -127,8 +127,8 @@
  81. softmax_scale,
  82. False,
  83. is_causal,
  84. - window_size - 1, # window_size_left
  85. - -1, # window_size_right
  86. + # window_size - 1, # window_size_left
  87. + # -1, # window_size_right
  88. return_softmax,
  89. None,
  90. )
  91. @@ -169,8 +169,8 @@
  92. p,
  93. softmax_scale,
  94. is_causal,
  95. - window_size - 1, # window_size_left
  96. - -1, # window_size_right
  97. + # window_size - 1, # window_size_left
  98. + # -1, # window_size_right
  99. None,
  100. rng_state,
  101. )
  102. @@ -193,15 +193,15 @@
  103. softmax_scale,
  104. False, # zero_tensors
  105. is_causal,
  106. - window_size - 1, # window_size_left
  107. - -1, # window_size_right
  108. + # window_size - 1, # window_size_left
  109. + # -1, # window_size_right
  110. None,
  111. rng_state,
  112. )
  113. return dq, dk, dv
  114. - _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
  115. - _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
  116. + #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
  117. + #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
  118. except ImportError:
  119. pass
  120. @@ -348,7 +348,7 @@
  121. implementation.
  122. """
  123. - OPERATOR = get_operator("xformers_flash", "flash_fwd")
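  124. + OPERATOR = _flash_fwd # call flash-attn directly; get_operator("xformers_flash", "flash_fwd") is unavailable because this patch disables the xformers_flash library registration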
  125. SUPPORTED_DEVICES: Set[str] = {"cuda"}
  126. CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
  127. SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}