setup.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. # Copyright (c) 2023, Tri Dao.
  2. import sys
  3. import functools
  4. import warnings
  5. import os
  6. import re
  7. import ast
  8. import glob
  9. import shutil
  10. from pathlib import Path
  11. from packaging.version import parse, Version
  12. import platform
  13. from setuptools import setup, find_packages
  14. import subprocess
  15. import urllib.request
  16. import urllib.error
  17. from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
  18. import torch
  19. from torch.utils.cpp_extension import (
  20. BuildExtension,
  21. CppExtension,
  22. CUDAExtension,
  23. CUDA_HOME,
  24. ROCM_HOME,
  25. IS_HIP_EXTENSION,
  26. )
  27. with open("README.md", "r", encoding="utf-8") as fh:
  28. long_description = fh.read()
  29. # ninja build does not work unless include_dirs are abs path
  30. this_dir = os.path.dirname(os.path.abspath(__file__))
  31. BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
  32. if BUILD_TARGET == "auto":
  33. if IS_HIP_EXTENSION:
  34. IS_ROCM = True
  35. else:
  36. IS_ROCM = False
  37. else:
  38. if BUILD_TARGET == "cuda":
  39. IS_ROCM = False
  40. elif BUILD_TARGET == "rocm":
  41. IS_ROCM = True
  42. PACKAGE_NAME = "flash_attn"
  43. BASE_WHEEL_URL = (
  44. "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
  45. )
  46. # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
  47. # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
  48. FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
  49. SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
  50. # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
  51. FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
  52. USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
  53. @functools.lru_cache(maxsize=None)
  54. def cuda_archs() -> str:
  55. return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";")
  56. def get_platform():
  57. """
  58. Returns the platform name as used in wheel filenames.
  59. """
  60. if sys.platform.startswith("linux"):
  61. return f'linux_{platform.uname().machine}'
  62. elif sys.platform == "darwin":
  63. mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
  64. return f"macosx_{mac_version}_x86_64"
  65. elif sys.platform == "win32":
  66. return "win_amd64"
  67. else:
  68. raise ValueError("Unsupported platform: {}".format(sys.platform))
  69. def get_cuda_bare_metal_version(cuda_dir):
  70. raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
  71. output = raw_output.split()
  72. release_idx = output.index("release") + 1
  73. bare_metal_version = parse(output[release_idx].split(",")[0])
  74. return raw_output, bare_metal_version
  75. def get_hip_version():
  76. return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
  77. def check_if_cuda_home_none(global_option: str) -> None:
  78. if CUDA_HOME is not None:
  79. return
  80. # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
  81. # in that case.
  82. warnings.warn(
  83. f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
  84. "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
  85. "only images whose names contain 'devel' will provide nvcc."
  86. )
  87. def check_if_rocm_home_none(global_option: str) -> None:
  88. if ROCM_HOME is not None:
  89. return
  90. # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
  91. # in that case.
  92. warnings.warn(
  93. f"{global_option} was requested, but hipcc was not found."
  94. )
  95. def append_nvcc_threads(nvcc_extra_args):
  96. nvcc_threads = os.getenv("NVCC_THREADS") or "2"
  97. return nvcc_extra_args + ["--threads", nvcc_threads]
  98. def rename_cpp_to_cu(cpp_files):
  99. for entry in cpp_files:
  100. shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
  101. def validate_and_update_archs(archs):
  102. # List of allowed architectures
  103. allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
  104. # Validate if each element in archs is in allowed_archs
  105. assert all(
  106. arch in allowed_archs for arch in archs
  107. ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
  108. cmdclass = {}
  109. ext_modules = []
  110. # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
  111. # files included in the source distribution, in case the user compiles from source.
  112. if os.path.isdir(".git"):
  113. subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True)
  114. subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True)
  115. else:
  116. if IS_ROCM:
  117. if not USE_TRITON_ROCM:
  118. assert (
  119. os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py")
  120. ), "csrc/composable_kernel is missing, please use source distribution or git clone"
  121. else:
  122. assert (
  123. os.path.exists("csrc/cutlass/include/cutlass/cutlass.h")
  124. ), "csrc/cutlass is missing, please use source distribution or git clone"
  125. if not SKIP_CUDA_BUILD and not IS_ROCM:
  126. print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
  127. TORCH_MAJOR = int(torch.__version__.split(".")[0])
  128. TORCH_MINOR = int(torch.__version__.split(".")[1])
  129. check_if_cuda_home_none("flash_attn")
  130. # Check, if CUDA11 is installed for compute capability 8.0
  131. cc_flag = []
  132. if CUDA_HOME is not None:
  133. _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
  134. if bare_metal_version < Version("11.7"):
  135. raise RuntimeError(
  136. "FlashAttention is only supported on CUDA 11.7 and above. "
  137. "Note: make sure nvcc has a supported version by running nvcc -V."
  138. )
  139. if "80" in cuda_archs():
  140. cc_flag.append("-gencode")
  141. cc_flag.append("arch=compute_80,code=sm_80")
  142. if CUDA_HOME is not None:
  143. if bare_metal_version >= Version("11.8") and "90" in cuda_archs():
  144. cc_flag.append("-gencode")
  145. cc_flag.append("arch=compute_90,code=sm_90")
  146. if bare_metal_version >= Version("12.8") and "100" in cuda_archs():
  147. cc_flag.append("-gencode")
  148. cc_flag.append("arch=compute_100,code=sm_100")
  149. if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
  150. cc_flag.append("-gencode")
  151. cc_flag.append("arch=compute_120,code=sm_120")
  152. # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
  153. # torch._C._GLIBCXX_USE_CXX11_ABI
  154. # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
  155. if FORCE_CXX11_ABI:
  156. torch._C._GLIBCXX_USE_CXX11_ABI = True
  157. ext_modules.append(
  158. CUDAExtension(
  159. name="flash_attn_2_cuda",
  160. sources=[
  161. "csrc/flash_attn/flash_api.cpp",
  162. "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
  163. "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
  164. "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
  165. "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
  166. "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
  167. "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
  168. "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
  169. "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
  170. "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
  171. "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
  172. "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
  173. "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
  174. "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
  175. "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
  176. "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
  177. "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
  178. "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
  179. "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
  180. "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
  181. "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
  182. "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
  183. "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
  184. "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
  185. "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
  186. "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
  187. "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
  188. "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
  189. "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
  190. "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
  191. "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
  192. "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
  193. "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
  194. "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
  195. "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
  196. "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
  197. "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
  198. "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
  199. "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
  200. "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
  201. "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
  202. "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
  203. "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
  204. "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
  205. "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
  206. "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
  207. "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
  208. "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
  209. "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
  210. "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
  211. "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
  212. "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
  213. "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
  214. "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
  215. "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
  216. "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
  217. "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
  218. "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
  219. "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
  220. "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
  221. "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
  222. "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
  223. "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
  224. "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
  225. "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
  226. "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
  227. "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
  228. "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
  229. "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
  230. "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
  231. "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
  232. "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
  233. "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
  234. "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
  235. "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
  236. "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
  237. "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
  238. "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
  239. "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
  240. "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
  241. "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
  242. "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
  243. "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
  244. "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
  245. "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
  246. ],
  247. extra_compile_args={
  248. "cxx": ["-O3", "-std=c++17"],
  249. "nvcc": append_nvcc_threads(
  250. [
  251. "-O3",
  252. "-std=c++17",
  253. "-U__CUDA_NO_HALF_OPERATORS__",
  254. "-U__CUDA_NO_HALF_CONVERSIONS__",
  255. "-U__CUDA_NO_HALF2_OPERATORS__",
  256. "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
  257. "--expt-relaxed-constexpr",
  258. "--expt-extended-lambda",
  259. "--use_fast_math",
  260. # "--ptxas-options=-v",
  261. # "--ptxas-options=-O2",
  262. # "-lineinfo",
  263. # "-DFLASHATTENTION_DISABLE_BACKWARD",
  264. # "-DFLASHATTENTION_DISABLE_DROPOUT",
  265. # "-DFLASHATTENTION_DISABLE_ALIBI",
  266. # "-DFLASHATTENTION_DISABLE_SOFTCAP",
  267. # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
  268. # "-DFLASHATTENTION_DISABLE_LOCAL",
  269. ]
  270. + cc_flag
  271. ),
  272. },
  273. include_dirs=[
  274. Path(this_dir) / "csrc" / "flash_attn",
  275. Path(this_dir) / "csrc" / "flash_attn" / "src",
  276. Path(this_dir) / "csrc" / "cutlass" / "include",
  277. ],
  278. )
  279. )
  280. elif not SKIP_CUDA_BUILD and IS_ROCM:
  281. print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
  282. TORCH_MAJOR = int(torch.__version__.split(".")[0])
  283. TORCH_MINOR = int(torch.__version__.split(".")[1])
  284. if USE_TRITON_ROCM:
  285. # Skip C++ extension compilation if using Triton Backend
  286. pass
  287. else:
  288. ck_dir = "csrc/composable_kernel"
  289. #use codegen get code dispatch
  290. if not os.path.exists("./build"):
  291. os.makedirs("build")
  292. subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2"], check=True)
  293. subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2"], check=True)
  294. subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2"], check=True)
  295. subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2"], check=True)
  296. # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
  297. # See https://github.com/pytorch/pytorch/pull/70650
  298. generator_flag = []
  299. torch_dir = torch.__path__[0]
  300. if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
  301. generator_flag = ["-DOLD_GENERATOR_PATH"]
  302. check_if_rocm_home_none("flash_attn")
  303. archs = os.getenv("GPU_ARCHS", "native").split(";")
  304. validate_and_update_archs(archs)
  305. cc_flag = [f"--offload-arch={arch}" for arch in archs]
  306. # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
  307. # torch._C._GLIBCXX_USE_CXX11_ABI
  308. # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
  309. if FORCE_CXX11_ABI:
  310. torch._C._GLIBCXX_USE_CXX11_ABI = True
  311. sources = ["csrc/flash_attn_ck/flash_api.cpp",
  312. "csrc/flash_attn_ck/flash_common.cpp",
  313. "csrc/flash_attn_ck/mha_bwd.cpp",
  314. "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
  315. "csrc/flash_attn_ck/mha_fwd.cpp",
  316. "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
  317. "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
  318. f"build/fmha_*wd*.cpp"
  319. )
  320. rename_cpp_to_cu(sources)
  321. renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
  322. "csrc/flash_attn_ck/flash_common.cu",
  323. "csrc/flash_attn_ck/mha_bwd.cu",
  324. "csrc/flash_attn_ck/mha_fwd_kvcache.cu",
  325. "csrc/flash_attn_ck/mha_fwd.cu",
  326. "csrc/flash_attn_ck/mha_varlen_bwd.cu",
  327. "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
  328. cc_flag += ["-O3","-std=c++17",
  329. "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
  330. "-fgpu-flush-denormals-to-zero",
  331. "-DCK_ENABLE_BF16",
  332. "-DCK_ENABLE_BF8",
  333. "-DCK_ENABLE_FP16",
  334. "-DCK_ENABLE_FP32",
  335. "-DCK_ENABLE_FP64",
  336. "-DCK_ENABLE_FP8",
  337. "-DCK_ENABLE_INT8",
  338. "-DCK_USE_XDL",
  339. "-DUSE_PROF_API=1",
  340. # "-DFLASHATTENTION_DISABLE_BACKWARD",
  341. "-D__HIP_PLATFORM_HCC__=1"]
  342. cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]
  343. # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
  344. hip_version = get_hip_version()
  345. if hip_version > Version('5.7.23302'):
  346. cc_flag += ["-fno-offload-uniform-block"]
  347. if hip_version > Version('6.1.40090'):
  348. cc_flag += ["-mllvm", "-enable-post-misched=0"]
  349. if hip_version > Version('6.2.41132'):
  350. cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
  351. "-mllvm", "-amdgpu-function-calls=false"]
  352. if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
  353. cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]
  354. extra_compile_args = {
  355. "cxx": ["-O3", "-std=c++17"] + generator_flag,
  356. "nvcc": cc_flag + generator_flag,
  357. }
  358. include_dirs = [
  359. Path(this_dir) / "csrc" / "composable_kernel" / "include",
  360. Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
  361. Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
  362. ]
  363. ext_modules.append(
  364. CUDAExtension(
  365. name="flash_attn_2_cuda",
  366. sources=renamed_sources,
  367. extra_compile_args=extra_compile_args,
  368. include_dirs=include_dirs,
  369. )
  370. )
  371. def get_package_version():
  372. with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
  373. version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
  374. public_version = ast.literal_eval(version_match.group(1))
  375. local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
  376. if local_version:
  377. return f"{public_version}+{local_version}"
  378. else:
  379. return str(public_version)
  380. def get_wheel_url():
  381. torch_version_raw = parse(torch.__version__)
  382. python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
  383. platform_name = get_platform()
  384. flash_version = get_package_version()
  385. torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
  386. cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
  387. if IS_ROCM:
  388. torch_hip_version = get_hip_version()
  389. hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
  390. wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
  391. else:
  392. # Determine the version numbers that will be used to determine the correct wheel
  393. # We're using the CUDA version used to build torch, not the one currently installed
  394. # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
  395. torch_cuda_version = parse(torch.version.cuda)
  396. # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
  397. # to save CI time. Minor versions should be compatible.
  398. torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
  399. # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
  400. cuda_version = f"{torch_cuda_version.major}"
  401. # Determine wheel URL based on CUDA version, torch version, python version and OS
  402. wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
  403. wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
  404. return wheel_url, wheel_filename
  405. class CachedWheelsCommand(_bdist_wheel):
  406. """
  407. The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
  408. find an existing wheel (which is currently the case for all flash attention installs). We use
  409. the environment parameters to detect whether there is already a pre-built version of a compatible
  410. wheel available and short-circuits the standard full build pipeline.
  411. """
  412. def run(self):
  413. if FORCE_BUILD:
  414. return super().run()
  415. wheel_url, wheel_filename = get_wheel_url()
  416. print("Guessing wheel URL: ", wheel_url)
  417. try:
  418. urllib.request.urlretrieve(wheel_url, wheel_filename)
  419. # Make the archive
  420. # Lifted from the root wheel processing command
  421. # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
  422. if not os.path.exists(self.dist_dir):
  423. os.makedirs(self.dist_dir)
  424. impl_tag, abi_tag, plat_tag = self.get_tag()
  425. archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
  426. wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
  427. print("Raw wheel path", wheel_path)
  428. os.rename(wheel_filename, wheel_path)
  429. except (urllib.error.HTTPError, urllib.error.URLError):
  430. print("Precompiled wheel not found. Building from source...")
  431. # If the wheel could not be downloaded, build from source
  432. super().run()
  433. class NinjaBuildExtension(BuildExtension):
  434. def __init__(self, *args, **kwargs) -> None:
  435. # do not override env MAX_JOBS if already exists
  436. if not os.environ.get("MAX_JOBS"):
  437. import psutil
  438. # calculate the maximum allowed NUM_JOBS based on cores
  439. max_num_jobs_cores = max(1, os.cpu_count() // 2)
  440. # calculate the maximum allowed NUM_JOBS based on free memory
  441. free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
  442. max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4
  443. # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
  444. max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
  445. os.environ["MAX_JOBS"] = str(max_jobs)
  446. super().__init__(*args, **kwargs)
  447. setup(
  448. name=PACKAGE_NAME,
  449. version=get_package_version(),
  450. packages=find_packages(
  451. exclude=(
  452. "build",
  453. "csrc",
  454. "include",
  455. "tests",
  456. "dist",
  457. "docs",
  458. "benchmarks",
  459. "flash_attn.egg-info",
  460. )
  461. ),
  462. author="Tri Dao",
  463. author_email="tri@tridao.me",
  464. description="Flash Attention: Fast and Memory-Efficient Exact Attention",
  465. long_description=long_description,
  466. long_description_content_type="text/markdown",
  467. url="https://github.com/Dao-AILab/flash-attention",
  468. classifiers=[
  469. "Programming Language :: Python :: 3",
  470. "License :: OSI Approved :: BSD License",
  471. "Operating System :: Unix",
  472. ],
  473. ext_modules=ext_modules,
  474. cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
  475. if ext_modules
  476. else {
  477. "bdist_wheel": CachedWheelsCommand,
  478. },
  479. python_requires=">=3.9",
  480. install_requires=[
  481. "torch",
  482. "einops",
  483. ],
  484. setup_requires=[
  485. "packaging",
  486. "psutil",
  487. "ninja",
  488. ],
  489. )