setup.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. # Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  2. import sys
  3. import warnings
  4. import os
  5. import stat
  6. import re
  7. import shutil
  8. import ast
  9. from pathlib import Path
  10. from packaging.version import parse, Version
  11. import platform
  12. import sysconfig
  13. import tarfile
  14. import itertools
  15. from setuptools import setup, find_packages
  16. import subprocess
  17. import urllib.request
  18. import urllib.error
  19. from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
  20. import torch
  21. from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
  22. # with open("../README.md", "r", encoding="utf-8") as fh:
  23. with open("../README.md", "r", encoding="utf-8") as fh:
  24. long_description = fh.read()
  25. # ninja build does not work unless include_dirs are abs path
  26. this_dir = os.path.dirname(os.path.abspath(__file__))
  27. PACKAGE_NAME = "flash_attn_3"
  28. BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
  29. # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
  30. # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
  31. FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
  32. SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
  33. # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
  34. FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
  35. DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
  36. DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
  37. DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
  38. DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
  39. DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
  40. DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
  41. DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
  42. DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
  43. DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
  44. DISABLE_VARLEN = os.getenv("FLASH_ATTENTION_DISABLE_VARLEN", "FALSE") == "TRUE"
  45. DISABLE_CLUSTER = os.getenv("FLASH_ATTENTION_DISABLE_CLUSTER", "FALSE") == "TRUE"
  46. DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE"
  47. DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE"
  48. DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE"
  49. DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE"
  50. DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE"
  51. DISABLE_SM8x = os.getenv("FLASH_ATTENTION_DISABLE_SM80", "FALSE") == "TRUE"
  52. ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE"
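# HACK: we monkey patch pytorch's _write_ninja_file to pass
# "-gencode arch=compute_90a,code=sm_90a" to files ending in '_sm90.cu',
# "-gencode arch=compute_80,code=sm_80" to files ending in '_sm80.cu',
# "-gencode arch=compute_100a,code=sm_100a" to files ending in '_sm100.cu',
# and both the sm_90a and sm_80 gencodes to every other .cu file.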
  56. from torch.utils.cpp_extension import (
  57. IS_HIP_EXTENSION,
  58. COMMON_HIP_FLAGS,
  59. SUBPROCESS_DECODE_ARGS,
  60. IS_WINDOWS,
  61. get_cxx_compiler,
  62. _join_rocm_home,
  63. _join_cuda_home,
  64. _is_cuda_file,
  65. _maybe_write,
  66. )
  67. def _write_ninja_file(path,
  68. cflags,
  69. post_cflags,
  70. cuda_cflags,
  71. cuda_post_cflags,
  72. cuda_dlink_post_cflags,
  73. sources,
  74. objects,
  75. ldflags,
  76. library_target,
  77. with_cuda,
  78. **kwargs, # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension
  79. ) -> None:
  80. r"""Write a ninja file that does the desired compiling and linking.
  81. `path`: Where to write this file
  82. `cflags`: list of flags to pass to $cxx. Can be None.
  83. `post_cflags`: list of flags to append to the $cxx invocation. Can be None.
  84. `cuda_cflags`: list of flags to pass to $nvcc. Can be None.
  85. `cuda_postflags`: list of flags to append to the $nvcc invocation. Can be None.
  86. `sources`: list of paths to source files
  87. `objects`: list of desired paths to objects, one per source.
  88. `ldflags`: list of flags to pass to linker. Can be None.
  89. `library_target`: Name of the output library. Can be None; in that case,
  90. we do no linking.
  91. `with_cuda`: If we should be compiling with CUDA.
  92. """
  93. def sanitize_flags(flags):
  94. if flags is None:
  95. return []
  96. else:
  97. return [flag.strip() for flag in flags]
  98. cflags = sanitize_flags(cflags)
  99. post_cflags = sanitize_flags(post_cflags)
  100. cuda_cflags = sanitize_flags(cuda_cflags)
  101. cuda_post_cflags = sanitize_flags(cuda_post_cflags)
  102. cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
  103. ldflags = sanitize_flags(ldflags)
  104. # Sanity checks...
  105. assert len(sources) == len(objects)
  106. assert len(sources) > 0
  107. compiler = get_cxx_compiler()
  108. # Version 1.3 is required for the `deps` directive.
  109. config = ['ninja_required_version = 1.3']
  110. config.append(f'cxx = {compiler}')
  111. if with_cuda or cuda_dlink_post_cflags:
  112. if IS_HIP_EXTENSION:
  113. nvcc = _join_rocm_home('bin', 'hipcc')
  114. else:
  115. nvcc = _join_cuda_home('bin', 'nvcc')
  116. if "PYTORCH_NVCC" in os.environ:
  117. nvcc_from_env = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here
  118. else:
  119. nvcc_from_env = nvcc
  120. config.append(f'nvcc_from_env = {nvcc_from_env}')
  121. config.append(f'nvcc = {nvcc}')
  122. if IS_HIP_EXTENSION:
  123. post_cflags = COMMON_HIP_FLAGS + post_cflags
  124. flags = [f'cflags = {" ".join(cflags)}']
  125. flags.append(f'post_cflags = {" ".join(post_cflags)}')
  126. if with_cuda:
  127. flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
  128. flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
  129. cuda_post_cflags_sm80 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_80,code=sm_80' for s in cuda_post_cflags]
  130. flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}')
  131. cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80']
  132. flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}')
  133. cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags]
  134. flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}')
  135. flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
  136. flags.append(f'ldflags = {" ".join(ldflags)}')
  137. # Turn into absolute paths so we can emit them into the ninja build
  138. # file wherever it is.
  139. sources = [os.path.abspath(file) for file in sources]
  140. # See https://ninja-build.org/build.ninja.html for reference.
  141. compile_rule = ['rule compile']
  142. if IS_WINDOWS:
  143. compile_rule.append(
  144. ' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
  145. compile_rule.append(' deps = msvc')
  146. else:
  147. compile_rule.append(
  148. ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')
  149. compile_rule.append(' depfile = $out.d')
  150. compile_rule.append(' deps = gcc')
  151. if with_cuda:
  152. cuda_compile_rule = ['rule cuda_compile']
  153. nvcc_gendeps = ''
  154. # --generate-dependencies-with-compile is not supported by ROCm
  155. # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time.
  156. if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1':
  157. cuda_compile_rule.append(' depfile = $out.d')
  158. cuda_compile_rule.append(' deps = gcc')
  159. # Note: non-system deps with nvcc are only supported
  160. # on Linux so use --generate-dependencies-with-compile
  161. # to make this work on Windows too.
  162. nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'
  163. cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [
  164. f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80'
  165. ]
  166. cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [
  167. f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90'
  168. ]
  169. cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [
  170. f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100'
  171. ]
  172. cuda_compile_rule.append(
  173. f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')
  174. # Emit one build rule per source to enable incremental build.
  175. build = []
  176. for source_file, object_file in zip(sources, objects):
  177. is_cuda_source = _is_cuda_file(source_file) and with_cuda
  178. if is_cuda_source:
  179. if source_file.endswith('_sm90.cu'):
  180. rule = 'cuda_compile'
  181. elif source_file.endswith('_sm80.cu'):
  182. rule = 'cuda_compile_sm80'
  183. elif source_file.endswith('_sm100.cu'):
  184. rule = 'cuda_compile_sm100'
  185. else:
  186. rule = 'cuda_compile_sm80_sm90'
  187. else:
  188. rule = 'compile'
  189. if IS_WINDOWS:
  190. source_file = source_file.replace(':', '$:')
  191. object_file = object_file.replace(':', '$:')
  192. source_file = source_file.replace(" ", "$ ")
  193. object_file = object_file.replace(" ", "$ ")
  194. build.append(f'build {object_file}: {rule} {source_file}')
  195. if cuda_dlink_post_cflags:
  196. devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
  197. devlink_rule = ['rule cuda_devlink']
  198. devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags')
  199. devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
  200. objects += [devlink_out]
  201. else:
  202. devlink_rule, devlink = [], []
  203. if library_target is not None:
  204. link_rule = ['rule link']
  205. if IS_WINDOWS:
  206. cl_paths = subprocess.check_output(['where',
  207. 'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n')
  208. if len(cl_paths) >= 1:
  209. cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
  210. else:
  211. raise RuntimeError("MSVC is required to load C++ extensions")
  212. link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
  213. else:
  214. link_rule.append(' command = $cxx $in $ldflags -o $out')
  215. link = [f'build {library_target}: link {" ".join(objects)}']
  216. default = [f'default {library_target}']
  217. else:
  218. link_rule, link, default = [], [], []
  219. # 'Blocks' should be separated by newlines, for visual benefit.
  220. blocks = [config, flags, compile_rule]
  221. if with_cuda:
  222. blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined]
  223. blocks.append(cuda_compile_rule_sm80) # type: ignore[possibly-undefined]
  224. blocks.append(cuda_compile_rule_sm80_sm90) # type: ignore[possibly-undefined]
  225. blocks.append(cuda_compile_rule_sm100) # type: ignore[possibly-undefined]
  226. blocks += [devlink_rule, link_rule, build, devlink, link, default]
  227. content = "\n\n".join("\n".join(b) for b in blocks)
  228. # Ninja requires a new lines at the end of the .ninja file
  229. content += "\n"
  230. _maybe_write(path, content)
  231. # Monkey patching
  232. torch.utils.cpp_extension._write_ninja_file = _write_ninja_file
  233. def get_platform():
  234. """
  235. Returns the platform name as used in wheel filenames.
  236. """
  237. if sys.platform.startswith("linux"):
  238. return "linux_x86_64"
  239. elif sys.platform == "darwin":
  240. mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
  241. return f"macosx_{mac_version}_x86_64"
  242. elif sys.platform == "win32":
  243. return "win_amd64"
  244. else:
  245. raise ValueError("Unsupported platform: {}".format(sys.platform))
  246. def get_cuda_bare_metal_version(cuda_dir):
  247. raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
  248. output = raw_output.split()
  249. release_idx = output.index("release") + 1
  250. bare_metal_version = parse(output[release_idx].split(",")[0])
  251. return raw_output, bare_metal_version
  252. def check_if_cuda_home_none(global_option: str) -> None:
  253. if CUDA_HOME is not None:
  254. return
  255. # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
  256. # in that case.
  257. warnings.warn(
  258. f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
  259. "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
  260. "only images whose names contain 'devel' will provide nvcc."
  261. )
  262. # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
  263. def check_env_flag(name: str, default: str = "") -> bool:
  264. return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
  265. # Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py
  266. def is_offline_build() -> bool:
  267. """
  268. Downstream projects and distributions which bootstrap their own dependencies from scratch
  269. and run builds in offline sandboxes
  270. may set `FLASH_ATTENTION_OFFLINE_BUILD` in the build environment to prevent any attempts at downloading
  271. pinned dependencies from the internet or at using dependencies vendored in-tree.
  272. Dependencies must be defined using respective search paths (cf. `syspath_var_name` in `Package`).
  273. Missing dependencies lead to an early abortion.
  274. Dependencies' compatibility is not verified.
  275. Note that this flag isn't tested by the CI and does not provide any guarantees.
  276. """
  277. return check_env_flag("FLASH_ATTENTION_OFFLINE_BUILD", "")
  278. # Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py
  279. def get_flashattn_cache_path():
  280. user_home = os.getenv("FLASH_ATTENTION_HOME")
  281. if not user_home:
  282. user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or os.getenv("HOMEPATH") or None
  283. if not user_home:
  284. raise RuntimeError("Could not find user home directory")
  285. return os.path.join(user_home, ".flashattn")
  286. def open_url(url):
  287. user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'
  288. headers = {
  289. 'User-Agent': user_agent,
  290. }
  291. request = urllib.request.Request(url, None, headers)
  292. # Set timeout to 300 seconds to prevent the request from hanging forever.
  293. return urllib.request.urlopen(request, timeout=300)
  294. def download_and_copy(name, src_func, dst_path, version, url_func):
  295. if is_offline_build():
  296. return
  297. flashattn_cache_path = get_flashattn_cache_path()
  298. base_dir = os.path.dirname(__file__)
  299. system = platform.system()
  300. arch = platform.machine()
  301. arch = {"arm64": "aarch64"}.get(arch, arch)
  302. supported = {"Linux": "linux", "Darwin": "linux"}
  303. url = url_func(supported[system], arch, version)
  304. src_path = src_func(supported[system], arch, version)
  305. tmp_path = os.path.join(flashattn_cache_path, "nvidia", name) # path to cache the download
  306. dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
  307. src_path = os.path.join(tmp_path, src_path)
  308. download = not os.path.exists(src_path)
  309. if download:
  310. print(f'downloading and extracting {url} ...')
  311. file = tarfile.open(fileobj=open_url(url), mode="r|*")
  312. file.extractall(path=tmp_path)
  313. os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
  314. print(f'copy {src_path} to {dst_path} ...')
  315. if os.path.isdir(src_path):
  316. shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
  317. else:
  318. shutil.copy(src_path, dst_path)
  319. def nvcc_threads_args():
  320. nvcc_threads = os.getenv("NVCC_THREADS") or "2"
  321. return ["--threads", nvcc_threads]
  322. # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"}
  323. NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61"}
  324. exe_extension = sysconfig.get_config_var("EXE")
  325. cmdclass = {}
  326. ext_modules = []
  327. # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
  328. # files included in the source distribution, in case the user compiles from source.
  329. subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"])
  330. if not SKIP_CUDA_BUILD:
  331. print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
  332. TORCH_MAJOR = int(torch.__version__.split(".")[0])
  333. TORCH_MINOR = int(torch.__version__.split(".")[1])
  334. check_if_cuda_home_none(PACKAGE_NAME)
  335. _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
  336. if bare_metal_version < Version("12.3"):
  337. raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above")
  338. # ptxas 12.8 gives the best perf currently
  339. # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8
  340. # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have.
  341. if bare_metal_version != Version("12.8"):
  342. download_and_copy(
  343. name="nvcc",
  344. src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin",
  345. dst_path="bin",
  346. version=NVIDIA_TOOLCHAIN_VERSION["nvcc"],
  347. url_func=lambda system, arch, version:
  348. f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
  349. )
  350. download_and_copy(
  351. name="ptxas",
  352. src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas",
  353. dst_path="bin",
  354. version=NVIDIA_TOOLCHAIN_VERSION["ptxas"],
  355. url_func=lambda system, arch, version:
  356. f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
  357. )
  358. download_and_copy(
  359. name="ptxas",
  360. src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin",
  361. dst_path="nvvm/bin",
  362. version=NVIDIA_TOOLCHAIN_VERSION["ptxas"],
  363. url_func=lambda system, arch, version:
  364. f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
  365. )
  366. base_dir = os.path.dirname(__file__)
  367. ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin")
  368. nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}")
  369. # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc
  370. # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc
  371. os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"]
  372. os.environ["PYTORCH_NVCC"] = nvcc_path_new
  373. # Make nvcc executable, sometimes after the copy it loses its permissions
  374. os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC)
  375. cc_flag = []
  376. cc_flag.append("-gencode")
  377. cc_flag.append("arch=compute_90a,code=sm_90a")
  378. # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
  379. # torch._C._GLIBCXX_USE_CXX11_ABI
  380. # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
  381. if FORCE_CXX11_ABI:
  382. torch._C._GLIBCXX_USE_CXX11_ABI = True
  383. repo_dir = Path(this_dir).parent
  384. cutlass_dir = repo_dir / "csrc" / "cutlass"
  385. feature_args = (
  386. []
  387. + (["-DFLASHATTENTION_DISABLE_BACKWARD"] if DISABLE_BACKWARD else [])
  388. + (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else [])
  389. + (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else [])
  390. + (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else [])
  391. + (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else [])
  392. + (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else [])
  393. + (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else [])
  394. + (["-DFLASHATTENTION_DISABLE_FP16"] if DISABLE_FP16 else [])
  395. + (["-DFLASHATTENTION_DISABLE_FP8"] if DISABLE_FP8 else [])
  396. + (["-DFLASHATTENTION_DISABLE_VARLEN"] if DISABLE_VARLEN else [])
  397. + (["-DFLASHATTENTION_DISABLE_CLUSTER"] if DISABLE_CLUSTER else [])
  398. + (["-DFLASHATTENTION_DISABLE_HDIM64"] if DISABLE_HDIM64 else [])
  399. + (["-DFLASHATTENTION_DISABLE_HDIM96"] if DISABLE_HDIM96 else [])
  400. + (["-DFLASHATTENTION_DISABLE_HDIM128"] if DISABLE_HDIM128 else [])
  401. + (["-DFLASHATTENTION_DISABLE_HDIM192"] if DISABLE_HDIM192 else [])
  402. + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else [])
  403. + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else [])
  404. + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else [])
  405. )
  406. DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else [])
  407. DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else [])
  408. DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else [])
  409. HEAD_DIMENSIONS_BWD = (
  410. []
  411. + ([64] if not DISABLE_HDIM64 else [])
  412. + ([96] if not DISABLE_HDIM96 else [])
  413. + ([128] if not DISABLE_HDIM128 else [])
  414. + ([192] if not DISABLE_HDIM192 else [])
  415. + ([256] if not DISABLE_HDIM256 else [])
  416. )
  417. HEAD_DIMENSIONS_FWD = ["all", "diff"]
  418. HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD
  419. SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else [])
  420. PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else [])
  421. SOFTCAP = [""] + (["_softcap"] if not DISABLE_SOFTCAP else [])
  422. SOFTCAP_ALL = [""] if DISABLE_SOFTCAP else ["_softcapall"]
  423. PACKGQA = [""] + (["_packgqa"] if not DISABLE_PACKGQA else [])
  424. # We already always hard-code PackGQA=true for Sm8x
  425. sources_fwd_sm80 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}_sm80.cu"
  426. for hdim, dtype, split, paged, softcap in itertools.product(HEAD_DIMENSIONS_FWD_SM80, DTYPE_FWD_SM80, SPLIT, PAGEDKV, SOFTCAP_ALL)]
  427. # We already always hard-code PackGQA=true for Sm9x if PagedKV or Split
  428. sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu"
  429. for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)
  430. if not (packgqa and (paged or split))]
  431. sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu"
  432. for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)]
  433. sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu"
  434. for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP_ALL)]
  435. if DISABLE_BACKWARD:
  436. sources_bwd_sm90 = []
  437. sources_bwd_sm80 = []
  438. sources = (
  439. ["flash_api.cpp"]
  440. + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90
  441. + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90
  442. )
  443. if not DISABLE_SPLIT:
  444. sources += ["flash_fwd_combine.cu"]
  445. sources += ["flash_prepare_scheduler.cu"]
  446. nvcc_flags = [
  447. "-O3",
  448. "-std=c++17",
  449. "--ftemplate-backtrace-limit=0", # To debug template code
  450. "--use_fast_math",
  451. # "--keep",
  452. # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", # printing out number of registers
  453. "--resource-usage", # printing out number of registers
  454. # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster
  455. "-lineinfo",
  456. "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use
  457. # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL
  458. "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging
  459. "-DNDEBUG", # Important, otherwise performance is severely impacted
  460. ]
  461. if get_platform() == "win_amd64":
  462. nvcc_flags.extend(
  463. [
  464. "-D_USE_MATH_DEFINES", # for M_LN2
  465. "-Xcompiler=/Zc:__cplusplus", # sets __cplusplus correctly, CUTLASS_CONSTEXPR_IF_CXX17 needed for cutlass::gcd
  466. ]
  467. )
  468. include_dirs = [
  469. Path(this_dir),
  470. cutlass_dir / "include",
  471. ]
  472. ext_modules.append(
  473. CUDAExtension(
  474. name="flash_attn_3_cuda",
  475. sources=sources,
  476. extra_compile_args={
  477. "cxx": ["-O3", "-std=c++17"] + feature_args,
  478. "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args,
  479. },
  480. include_dirs=include_dirs,
  481. )
  482. )
  483. def get_package_version():
  484. with open(Path(this_dir) / "__init__.py", "r") as f:
  485. version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
  486. public_version = ast.literal_eval(version_match.group(1))
  487. local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
  488. if local_version:
  489. return f"{public_version}+{local_version}"
  490. else:
  491. return str(public_version)
  492. def get_wheel_url():
  493. # Determine the version numbers that will be used to determine the correct wheel
  494. # We're using the CUDA version used to build torch, not the one currently installed
  495. # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
  496. torch_cuda_version = parse(torch.version.cuda)
  497. torch_version_raw = parse(torch.__version__)
  498. # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
  499. # to save CI time. Minor versions should be compatible.
  500. torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
  501. python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
  502. platform_name = get_platform()
  503. package_version = get_package_version()
  504. # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
  505. cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
  506. torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
  507. cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
  508. # Determine wheel URL based on CUDA version, torch version, python version and OS
  509. wheel_filename = f"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
  510. wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{package_version}", wheel_name=wheel_filename)
  511. return wheel_url, wheel_filename
  512. class CachedWheelsCommand(_bdist_wheel):
  513. """
  514. The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
  515. find an existing wheel (which is currently the case for all installs). We use
  516. the environment parameters to detect whether there is already a pre-built version of a compatible
  517. wheel available and short-circuits the standard full build pipeline.
  518. """
  519. def run(self):
  520. if FORCE_BUILD:
  521. return super().run()
  522. wheel_url, wheel_filename = get_wheel_url()
  523. print("Guessing wheel URL: ", wheel_url)
  524. try:
  525. urllib.request.urlretrieve(wheel_url, wheel_filename)
  526. # Make the archive
  527. # Lifted from the root wheel processing command
  528. # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
  529. if not os.path.exists(self.dist_dir):
  530. os.makedirs(self.dist_dir)
  531. impl_tag, abi_tag, plat_tag = self.get_tag()
  532. archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
  533. wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
  534. print("Raw wheel path", wheel_path)
  535. shutil.move(wheel_filename, wheel_path)
  536. except urllib.error.HTTPError:
  537. print("Precompiled wheel not found. Building from source...")
  538. # If the wheel could not be downloaded, build from source
  539. super().run()
  540. setup(
  541. name=PACKAGE_NAME,
  542. version=get_package_version(),
  543. packages=find_packages(
  544. exclude=(
  545. "build",
  546. "csrc",
  547. "include",
  548. "tests",
  549. "dist",
  550. "docs",
  551. "benchmarks",
  552. )
  553. ),
  554. py_modules=["flash_attn_interface"],
  555. description="FlashAttention-3",
  556. long_description=long_description,
  557. long_description_content_type="text/markdown",
  558. classifiers=[
  559. "Programming Language :: Python :: 3",
  560. "License :: OSI Approved :: Apache Software License",
  561. "Operating System :: Unix",
  562. ],
  563. ext_modules=ext_modules,
  564. cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
  565. if ext_modules
  566. else {
  567. "bdist_wheel": CachedWheelsCommand,
  568. },
  569. python_requires=">=3.8",
  570. install_requires=[
  571. "torch",
  572. "einops",
  573. "packaging",
  574. "ninja",
  575. ],
  576. )