setup.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. # Copyright (c) 2023, Tri Dao.
  2. import sys
  3. import warnings
  4. import os
  5. import re
  6. import ast
  7. import glob
  8. import shutil
  9. from pathlib import Path
  10. from packaging.version import parse, Version
  11. import platform
  12. from setuptools import setup, find_packages
  13. import subprocess
  14. import urllib.request
  15. import urllib.error
  16. from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
  17. import torch
  18. from torch.utils.cpp_extension import (
  19. BuildExtension,
  20. CppExtension,
  21. CUDAExtension,
  22. CUDA_HOME,
  23. ROCM_HOME,
  24. IS_HIP_EXTENSION,
  25. )
  26. with open("README.md", "r", encoding="utf-8") as fh:
  27. long_description = fh.read()
  28. # ninja build does not work unless include_dirs are abs path
  29. this_dir = os.path.dirname(os.path.abspath(__file__))
  30. BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
  31. if BUILD_TARGET == "auto":
  32. if IS_HIP_EXTENSION:
  33. IS_ROCM = True
  34. else:
  35. IS_ROCM = False
  36. else:
  37. if BUILD_TARGET == "cuda":
  38. IS_ROCM = False
  39. elif BUILD_TARGET == "rocm":
  40. IS_ROCM = True
  41. PACKAGE_NAME = "flash_attn"
  42. BASE_WHEEL_URL = (
  43. "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
  44. )
  45. # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
  46. # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
  47. FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
  48. SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
  49. # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
  50. FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
  51. USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
  52. def get_platform():
  53. """
  54. Returns the platform name as used in wheel filenames.
  55. """
  56. if sys.platform.startswith("linux"):
  57. return f'linux_{platform.uname().machine}'
  58. elif sys.platform == "darwin":
  59. mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
  60. return f"macosx_{mac_version}_x86_64"
  61. elif sys.platform == "win32":
  62. return "win_amd64"
  63. else:
  64. raise ValueError("Unsupported platform: {}".format(sys.platform))
  65. def get_cuda_bare_metal_version(cuda_dir):
  66. raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
  67. output = raw_output.split()
  68. release_idx = output.index("release") + 1
  69. bare_metal_version = parse(output[release_idx].split(",")[0])
  70. return raw_output, bare_metal_version
  71. def get_hip_version():
  72. return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
  73. def check_if_cuda_home_none(global_option: str) -> None:
  74. if CUDA_HOME is not None:
  75. return
  76. # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
  77. # in that case.
  78. warnings.warn(
  79. f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
  80. "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
  81. "only images whose names contain 'devel' will provide nvcc."
  82. )
  83. def check_if_rocm_home_none(global_option: str) -> None:
  84. if ROCM_HOME is not None:
  85. return
  86. # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
  87. # in that case.
  88. warnings.warn(
  89. f"{global_option} was requested, but hipcc was not found."
  90. )
  91. def append_nvcc_threads(nvcc_extra_args):
  92. nvcc_threads = os.getenv("NVCC_THREADS") or "2"
  93. return nvcc_extra_args + ["--threads", nvcc_threads]
  94. def rename_cpp_to_cu(cpp_files):
  95. for entry in cpp_files:
  96. shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
  97. def validate_and_update_archs(archs):
  98. # List of allowed architectures
  99. allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
  100. # Validate if each element in archs is in allowed_archs
  101. assert all(
  102. arch in allowed_archs for arch in archs
  103. ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
  104. cmdclass = {}
  105. ext_modules = []
  106. # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
  107. # files included in the source distribution, in case the user compiles from source.
  108. if IS_ROCM:
  109. if not USE_TRITON_ROCM:
  110. subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
  111. else:
  112. subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
  113. if not SKIP_CUDA_BUILD and not IS_ROCM:
  114. print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
  115. TORCH_MAJOR = int(torch.__version__.split(".")[0])
  116. TORCH_MINOR = int(torch.__version__.split(".")[1])
  117. check_if_cuda_home_none("flash_attn")
  118. # Check, if CUDA11 is installed for compute capability 8.0
  119. cc_flag = []
  120. if CUDA_HOME is not None:
  121. _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
  122. if bare_metal_version < Version("11.7"):
  123. raise RuntimeError(
  124. "FlashAttention is only supported on CUDA 11.7 and above. "
  125. "Note: make sure nvcc has a supported version by running nvcc -V."
  126. )
  127. # cc_flag.append("-gencode")
  128. # cc_flag.append("arch=compute_75,code=sm_75")
  129. cc_flag.append("-gencode")
  130. cc_flag.append("arch=compute_80,code=sm_80")
  131. if CUDA_HOME is not None:
  132. if bare_metal_version >= Version("11.8"):
  133. cc_flag.append("-gencode")
  134. cc_flag.append("arch=compute_90,code=sm_90")
  135. # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
  136. # torch._C._GLIBCXX_USE_CXX11_ABI
  137. # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
  138. if FORCE_CXX11_ABI:
  139. torch._C._GLIBCXX_USE_CXX11_ABI = True
  140. ext_modules.append(
  141. CUDAExtension(
  142. name="flash_attn_2_cuda",
  143. sources=[
  144. "csrc/flash_attn/flash_api.cpp",
  145. "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
  146. "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
  147. "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
  148. "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
  149. "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
  150. "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
  151. "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
  152. "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
  153. "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
  154. "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
  155. "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
  156. "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
  157. "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
  158. "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
  159. "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
  160. "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
  161. "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
  162. "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
  163. "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
  164. "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
  165. "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
  166. "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
  167. "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
  168. "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
  169. "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
  170. "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
  171. "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
  172. "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
  173. "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
  174. "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
  175. "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
  176. "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
  177. "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
  178. "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
  179. "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
  180. "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
  181. "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
  182. "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
  183. "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
  184. "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
  185. "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
  186. "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
  187. "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
  188. "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
  189. "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
  190. "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
  191. "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
  192. "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
  193. "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
  194. "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
  195. "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
  196. "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
  197. "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
  198. "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
  199. "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
  200. "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
  201. "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
  202. "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
  203. "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
  204. "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
  205. "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
  206. "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
  207. "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
  208. "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
  209. "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
  210. "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
  211. "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
  212. "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
  213. "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
  214. "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
  215. "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
  216. "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
  217. "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
  218. "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
  219. "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
  220. "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
  221. "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
  222. "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
  223. "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
  224. "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
  225. "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
  226. "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
  227. "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
  228. "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
  229. ],
  230. extra_compile_args={
  231. "cxx": ["-O3", "-std=c++17"],
  232. "nvcc": append_nvcc_threads(
  233. [
  234. "-O3",
  235. "-std=c++17",
  236. "-U__CUDA_NO_HALF_OPERATORS__",
  237. "-U__CUDA_NO_HALF_CONVERSIONS__",
  238. "-U__CUDA_NO_HALF2_OPERATORS__",
  239. "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
  240. "--expt-relaxed-constexpr",
  241. "--expt-extended-lambda",
  242. "--use_fast_math",
  243. # "--ptxas-options=-v",
  244. # "--ptxas-options=-O2",
  245. # "-lineinfo",
  246. # "-DFLASHATTENTION_DISABLE_BACKWARD",
  247. # "-DFLASHATTENTION_DISABLE_DROPOUT",
  248. # "-DFLASHATTENTION_DISABLE_ALIBI",
  249. # "-DFLASHATTENTION_DISABLE_SOFTCAP",
  250. # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
  251. # "-DFLASHATTENTION_DISABLE_LOCAL",
  252. ]
  253. + cc_flag
  254. ),
  255. },
  256. include_dirs=[
  257. Path(this_dir) / "csrc" / "flash_attn",
  258. Path(this_dir) / "csrc" / "flash_attn" / "src",
  259. Path(this_dir) / "csrc" / "cutlass" / "include",
  260. ],
  261. )
  262. )
  263. elif not SKIP_CUDA_BUILD and IS_ROCM:
  264. print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
  265. TORCH_MAJOR = int(torch.__version__.split(".")[0])
  266. TORCH_MINOR = int(torch.__version__.split(".")[1])
  267. if USE_TRITON_ROCM:
  268. # Skip C++ extension compilation if using Triton Backend
  269. pass
  270. else:
  271. ck_dir = "csrc/composable_kernel"
  272. #use codegen get code dispatch
  273. if not os.path.exists("./build"):
  274. os.makedirs("build")
  275. os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
  276. os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2")
  277. os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2")
  278. os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
  279. # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
  280. # See https://github.com/pytorch/pytorch/pull/70650
  281. generator_flag = []
  282. torch_dir = torch.__path__[0]
  283. if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
  284. generator_flag = ["-DOLD_GENERATOR_PATH"]
  285. check_if_rocm_home_none("flash_attn")
  286. archs = os.getenv("GPU_ARCHS", "native").split(";")
  287. validate_and_update_archs(archs)
  288. cc_flag = [f"--offload-arch={arch}" for arch in archs]
  289. # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
  290. # torch._C._GLIBCXX_USE_CXX11_ABI
  291. # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
  292. if FORCE_CXX11_ABI:
  293. torch._C._GLIBCXX_USE_CXX11_ABI = True
  294. sources = ["csrc/flash_attn_ck/flash_api.cpp",
  295. "csrc/flash_attn_ck/flash_common.cpp",
  296. "csrc/flash_attn_ck/mha_bwd.cpp",
  297. "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
  298. "csrc/flash_attn_ck/mha_fwd.cpp",
  299. "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
  300. "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
  301. f"build/fmha_*wd*.cpp"
  302. )
  303. rename_cpp_to_cu(sources)
  304. renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
  305. "csrc/flash_attn_ck/flash_common.cu",
  306. "csrc/flash_attn_ck/mha_bwd.cu",
  307. "csrc/flash_attn_ck/mha_fwd_kvcache.cu",
  308. "csrc/flash_attn_ck/mha_fwd.cu",
  309. "csrc/flash_attn_ck/mha_varlen_bwd.cu",
  310. "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
  311. cc_flag += ["-O3","-std=c++17",
  312. "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
  313. "-fgpu-flush-denormals-to-zero",
  314. "-DCK_ENABLE_BF16",
  315. "-DCK_ENABLE_BF8",
  316. "-DCK_ENABLE_FP16",
  317. "-DCK_ENABLE_FP32",
  318. "-DCK_ENABLE_FP64",
  319. "-DCK_ENABLE_FP8",
  320. "-DCK_ENABLE_INT8",
  321. "-DCK_USE_XDL",
  322. "-DUSE_PROF_API=1",
  323. # "-DFLASHATTENTION_DISABLE_BACKWARD",
  324. "-D__HIP_PLATFORM_HCC__=1"]
  325. cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]
  326. # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
  327. hip_version = get_hip_version()
  328. if hip_version > Version('5.7.23302'):
  329. cc_flag += ["-fno-offload-uniform-block"]
  330. if hip_version > Version('6.1.40090'):
  331. cc_flag += ["-mllvm", "-enable-post-misched=0"]
  332. if hip_version > Version('6.2.41132'):
  333. cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
  334. "-mllvm", "-amdgpu-function-calls=false"]
  335. if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
  336. cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]
  337. extra_compile_args = {
  338. "cxx": ["-O3", "-std=c++17"] + generator_flag,
  339. "nvcc": cc_flag + generator_flag,
  340. }
  341. include_dirs = [
  342. Path(this_dir) / "csrc" / "composable_kernel" / "include",
  343. Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
  344. Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
  345. ]
  346. ext_modules.append(
  347. CUDAExtension(
  348. name="flash_attn_2_cuda",
  349. sources=renamed_sources,
  350. extra_compile_args=extra_compile_args,
  351. include_dirs=include_dirs,
  352. )
  353. )
  354. def get_package_version():
  355. with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
  356. version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
  357. public_version = ast.literal_eval(version_match.group(1))
  358. local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
  359. if local_version:
  360. return f"{public_version}+{local_version}"
  361. else:
  362. return str(public_version)
  363. def get_wheel_url():
  364. torch_version_raw = parse(torch.__version__)
  365. python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
  366. platform_name = get_platform()
  367. flash_version = get_package_version()
  368. torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
  369. cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
  370. if IS_ROCM:
  371. torch_hip_version = get_hip_version()
  372. hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
  373. wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
  374. else:
  375. # Determine the version numbers that will be used to determine the correct wheel
  376. # We're using the CUDA version used to build torch, not the one currently installed
  377. # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
  378. torch_cuda_version = parse(torch.version.cuda)
  379. # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
  380. # to save CI time. Minor versions should be compatible.
  381. torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
  382. # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
  383. cuda_version = f"{torch_cuda_version.major}"
  384. # Determine wheel URL based on CUDA version, torch version, python version and OS
  385. wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
  386. wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
  387. return wheel_url, wheel_filename
  388. class CachedWheelsCommand(_bdist_wheel):
  389. """
  390. The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
  391. find an existing wheel (which is currently the case for all flash attention installs). We use
  392. the environment parameters to detect whether there is already a pre-built version of a compatible
  393. wheel available and short-circuits the standard full build pipeline.
  394. """
  395. def run(self):
  396. if FORCE_BUILD:
  397. return super().run()
  398. wheel_url, wheel_filename = get_wheel_url()
  399. print("Guessing wheel URL: ", wheel_url)
  400. try:
  401. urllib.request.urlretrieve(wheel_url, wheel_filename)
  402. # Make the archive
  403. # Lifted from the root wheel processing command
  404. # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
  405. if not os.path.exists(self.dist_dir):
  406. os.makedirs(self.dist_dir)
  407. impl_tag, abi_tag, plat_tag = self.get_tag()
  408. archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
  409. wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
  410. print("Raw wheel path", wheel_path)
  411. os.rename(wheel_filename, wheel_path)
  412. except (urllib.error.HTTPError, urllib.error.URLError):
  413. print("Precompiled wheel not found. Building from source...")
  414. # If the wheel could not be downloaded, build from source
  415. super().run()
  416. class NinjaBuildExtension(BuildExtension):
  417. def __init__(self, *args, **kwargs) -> None:
  418. # do not override env MAX_JOBS if already exists
  419. if not os.environ.get("MAX_JOBS"):
  420. import psutil
  421. # calculate the maximum allowed NUM_JOBS based on cores
  422. max_num_jobs_cores = max(1, os.cpu_count() // 2)
  423. # calculate the maximum allowed NUM_JOBS based on free memory
  424. free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
  425. max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4
  426. # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
  427. max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
  428. os.environ["MAX_JOBS"] = str(max_jobs)
  429. super().__init__(*args, **kwargs)
  430. setup(
  431. name=PACKAGE_NAME,
  432. version=get_package_version(),
  433. packages=find_packages(
  434. exclude=(
  435. "build",
  436. "csrc",
  437. "include",
  438. "tests",
  439. "dist",
  440. "docs",
  441. "benchmarks",
  442. "flash_attn.egg-info",
  443. )
  444. ),
  445. author="Tri Dao",
  446. author_email="tri@tridao.me",
  447. description="Flash Attention: Fast and Memory-Efficient Exact Attention",
  448. long_description=long_description,
  449. long_description_content_type="text/markdown",
  450. url="https://github.com/Dao-AILab/flash-attention",
  451. classifiers=[
  452. "Programming Language :: Python :: 3",
  453. "License :: OSI Approved :: BSD License",
  454. "Operating System :: Unix",
  455. ],
  456. ext_modules=ext_modules,
  457. cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
  458. if ext_modules
  459. else {
  460. "bdist_wheel": CachedWheelsCommand,
  461. },
  462. python_requires=">=3.9",
  463. install_requires=[
  464. "torch",
  465. "einops",
  466. ],
  467. setup_requires=[
  468. "packaging",
  469. "psutil",
  470. "ninja",
  471. ],
  472. )