flashpy_xformers-0.0.23.rocm.patch 5.6 KB

--- flash_ori.py	2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py	2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
-        if (
-            flash_ver_parsed != (2, 3, 6)
-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-        ):
-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+    #    if (
+    #        flash_ver_parsed != (2, 3, 6)
+    #        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+    #    ):
+    #        raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 
     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-
-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_left, "
-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, "
-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_left, "
+    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, "
+    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)
 
     def _flash_fwd(
         query,
@@ -111,8 +111,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,  # window_size_left
-                window_right,  # window_size_right
+                # window_left,  # window_size_left
+                # window_right,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -134,15 +134,15 @@
                 out,
                 cu_seq_lens_q,
                 cu_seq_lens_k,
-                seqused_k,
+                # seqused_k,
                 max_seq_len_q,
                 max_seq_len_k,
                 p,
                 softmax_scale,
                 False,
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 return_softmax,
                 None,
             )
@@ -184,8 +184,8 @@
                 p,
                 softmax_scale,
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 None,
                 rng_state,
             )
@@ -208,15 +208,15 @@
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_left,
-                window_right,
+                # window_left,
+                # window_right,
                 None,
                 rng_state,
             )
         return dq, dk, dv
 
-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 
 except ImportError:
     pass
@@ -400,7 +400,7 @@
         implementation.
     """
 
-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd  # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
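
Once this patch has been applied to xformers/ops/fmha/flash.py, the flash ops bind directly to the flash_attn package instead of going through the torch.library dispatcher. A minimal smoke test along the following lines can confirm the patched module imports and runs; it assumes the patched xformers and a ROCm build of flash-attn are installed, and the tensor shapes and explicit op selection are illustrative choices, not part of the patch.

# smoke_test_flash_rocm.py -- hypothetical check, not shipped with the patch
import torch
import xformers.ops as xops
from xformers.ops.fmha import flash

# Inputs in (batch, seq_len, num_heads, head_dim) layout, fp16 as listed in
# SUPPORTED_DTYPES above.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash-attention forward/backward ops so a silent fallback to
# another backend cannot mask an import failure in the patched module.
out = xops.memory_efficient_attention(q, k, v, op=(flash.FwOp, flash.BwOp))
print(flash.FLASH_VERSION, out.shape)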