123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424 |
- # Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- import sys
- import warnings
- import os
- import stat
- import re
- import shutil
- import ast
- from pathlib import Path
- from packaging.version import parse, Version
- import platform
- import sysconfig
- import tarfile
- import itertools
- from setuptools import setup, find_packages
- import subprocess
- import urllib.request
- import urllib.error
- from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
- import torch
- from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# Read the README one directory above this file for the package's long description.
# The path is anchored to this file's location rather than the current working
# directory, so the build does not depend on where it is invoked from.
with open(os.path.join(this_dir, os.pardir, "README.md"), "r", encoding="utf-8") as fh:
    long_description = fh.read()

PACKAGE_NAME = "flash_attn"

# URL template for prebuilt wheels; formatted with `tag_name` and `wheel_name`.
BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
def _env_is_true(var_name):
    """Return True only when environment variable `var_name` is exactly the string "TRUE"."""
    return os.getenv(var_name, "FALSE") == "TRUE"


# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = _env_is_true("FLASH_ATTENTION_FORCE_BUILD")
SKIP_CUDA_BUILD = _env_is_true("FLASH_ATTENTION_SKIP_CUDA_BUILD")
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = _env_is_true("FLASH_ATTENTION_FORCE_CXX11_ABI")

# Feature switches: each one is turned into a -DFLASHATTENTION_* compile definition
# and/or trims the list of kernel instantiation sources further down in this file.
DISABLE_BACKWARD = _env_is_true("FLASH_ATTENTION_DISABLE_BACKWARD")
DISABLE_SPLIT = _env_is_true("FLASH_ATTENTION_DISABLE_SPLIT")
DISABLE_PAGEDKV = _env_is_true("FLASH_ATTENTION_DISABLE_PAGEDKV")
DISABLE_APPENDKV = _env_is_true("FLASH_ATTENTION_DISABLE_APPENDKV")
DISABLE_LOCAL = _env_is_true("FLASH_ATTENTION_DISABLE_LOCAL")
DISABLE_SOFTCAP = _env_is_true("FLASH_ATTENTION_DISABLE_SOFTCAP")
DISABLE_PACKGQA = _env_is_true("FLASH_ATTENTION_DISABLE_PACKGQA")
DISABLE_FP16 = _env_is_true("FLASH_ATTENTION_DISABLE_FP16")
DISABLE_FP8 = _env_is_true("FLASH_ATTENTION_DISABLE_FP8")
DISABLE_VARLEN = _env_is_true("FLASH_ATTENTION_DISABLE_VARLEN")
DISABLE_CLUSTER = _env_is_true("FLASH_ATTENTION_DISABLE_CLUSTER")
DISABLE_HDIM64 = _env_is_true("FLASH_ATTENTION_DISABLE_HDIM64")
DISABLE_HDIM96 = _env_is_true("FLASH_ATTENTION_DISABLE_HDIM96")
DISABLE_HDIM128 = _env_is_true("FLASH_ATTENTION_DISABLE_HDIM128")
DISABLE_HDIM192 = _env_is_true("FLASH_ATTENTION_DISABLE_HDIM192")
DISABLE_HDIM256 = _env_is_true("FLASH_ATTENTION_DISABLE_HDIM256")
ENABLE_VCOLMAJOR = _env_is_true("FLASH_ATTENTION_ENABLE_VCOLMAJOR")
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    On Linux the machine architecture is taken from `platform.machine()` so that
    non-x86_64 hosts (e.g. aarch64) do not ask for an x86_64 wheel. The macOS and
    Windows branches keep their fixed architecture suffixes.

    Raises:
        ValueError: if `sys.platform` is not Linux, macOS or Windows.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.machine()}"
    elif sys.platform == "darwin":
        # Keep only "major.minor" of the macOS version string.
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))
def get_cuda_bare_metal_version(cuda_dir):
    """Run `<cuda_dir>/bin/nvcc -V` and extract the toolkit release version.

    Args:
        cuda_dir: CUDA installation root, as a string.

    Returns:
        A `(raw_output, version)` tuple: the full text printed by nvcc and the
        parsed version found right after the word "release" in that text.
    """
    nvcc_binary = cuda_dir + "/bin/nvcc"
    raw_output = subprocess.check_output([nvcc_binary, "-V"], universal_newlines=True)
    tokens = raw_output.split()
    # The token following "release" holds the version; anything from the first comma on is dropped.
    release_token = tokens[tokens.index("release") + 1]
    return raw_output, parse(release_token.split(",")[0])
def check_if_cuda_home_none(global_option: str) -> None:
    """Emit a warning naming `global_option` when `CUDA_HOME` is None; otherwise do nothing."""
    if CUDA_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
        # in that case.
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
    """Return True when env var `name` (or `default` if unset) is a truthy token, ignoring case."""
    truthy_tokens = ("ON", "1", "YES", "TRUE", "Y")
    raw_value = os.getenv(name, default)
    return raw_value.upper() in truthy_tokens
# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py
def is_offline_build() -> bool:
    """
    Report whether `FLASH_ATTENTION_OFFLINE_BUILD` is set to a truthy value.

    Downstream projects and distributions which bootstrap their own dependencies from scratch
    and run builds in offline sandboxes may set `FLASH_ATTENTION_OFFLINE_BUILD` in the build
    environment to prevent any attempts at downloading pinned dependencies from the internet
    or at using dependencies vendored in-tree.

    Dependencies must be defined using respective search paths (cf. `syspath_var_name` in
    `Package`; this wording is inherited from the Triton setup script).
    Missing dependencies lead to an early abort.
    Dependencies' compatibility is not verified.

    Note that this flag isn't tested by the CI and does not provide any guarantees.
    """
    offline_requested = check_env_flag("FLASH_ATTENTION_OFFLINE_BUILD", "")
    return offline_requested
# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py
def get_flashattn_cache_path():
    """Return `<home>/.flashattn`, the directory used to cache downloads.

    `<home>` is the first non-empty value among the environment variables
    FLASH_ATTENTION_HOME, HOME, USERPROFILE and HOMEPATH, checked in that order.

    Raises:
        RuntimeError: if none of those variables is set to a non-empty value.
    """
    for env_name in ("FLASH_ATTENTION_HOME", "HOME", "USERPROFILE", "HOMEPATH"):
        candidate = os.getenv(env_name)
        if candidate:
            return os.path.join(candidate, ".flashattn")
    raise RuntimeError("Could not find user home directory")
def open_url(url):
    """Open `url` with a browser-like User-Agent header and return the response object."""
    request = urllib.request.Request(
        url,
        None,
        {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'},
    )
    # Set timeout to 300 seconds to prevent the request from hanging forever.
    response = urllib.request.urlopen(request, timeout=300)
    return response
def download_and_copy(name, src_path, dst_path, version, url_func):
    """Download a toolchain archive into the cache (if needed) and copy part of it into the tree.

    Does nothing when `is_offline_build()` is true.

    Args:
        name: Sub-directory name under `<cache>/nvidia/` where the archive is extracted.
        src_path: Path inside the extracted archive to copy from, or a callable
            `(platform_name, version) -> path`.
        dst_path: Destination relative to `<this file's dir>/../third_party/nvidia/backend/`.
        version: Version string handed to `url_func` (and to `src_path` when callable).
        url_func: Callable `(system, arch, version) -> url` returning the archive URL.

    Raises:
        KeyError: if `platform.system()` is neither "Linux" nor "Darwin".
    """
    if is_offline_build():
        return
    flashattn_cache_path = get_flashattn_cache_path()
    base_dir = os.path.dirname(__file__)
    system = platform.system()
    machine = platform.machine()
    # Known machines map to the URL's arch token; unknown machine names are used as-is.
    arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}.get(machine, machine)
    # NOTE(review): "Darwin" maps to "linux" as well — confirm this is intended.
    supported = {"Linux": "linux", "Darwin": "linux"}
    url = url_func(supported[system], arch, version)
    tmp_path = os.path.join(flashattn_cache_path, "nvidia", name)  # path to cache the download
    dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path)  # final binary path
    platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
    src_path = src_path(platform_name, version) if callable(src_path) else src_path
    src_path = os.path.join(tmp_path, src_path)
    # Only download when the requested path is not already present in the cache.
    if not os.path.exists(src_path):
        print(f'downloading and extracting {url} ...')
        # Context managers ensure both the HTTP response and the tar stream are
        # closed, even if extraction fails part-way.
        with open_url(url) as response, tarfile.open(fileobj=response, mode="r|*") as archive:
            archive.extractall(path=tmp_path)
    os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
    print(f'copy {src_path} to {dst_path} ...')
    if os.path.isdir(src_path):
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    else:
        shutil.copy(src_path, dst_path)
def nvcc_threads_args():
    """Return the nvcc `--threads` option pair, using NVCC_THREADS or "4" when unset/empty."""
    thread_count = os.getenv("NVCC_THREADS")
    if not thread_count:
        thread_count = "4"
    return ["--threads", thread_count]
# Version of each NVIDIA toolchain component that may be downloaded for the build.
NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"}

# Executable suffix reported by sysconfig for this platform.
exe_extension = sysconfig.get_config_var("EXE")

cmdclass, ext_modules = {}, []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
# check=False: a non-zero exit status from git does not raise here.
subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"], check=False)
# Configure the CUDA extension unless the build was asked to skip CUDA compilation.
# This section validates the nvcc version, optionally fetches a pinned nvcc, assembles
# the source list and compiler flags, and appends one CUDAExtension to `ext_modules`.
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    # NOTE(review): TORCH_MAJOR / TORCH_MINOR are not referenced elsewhere in the visible file.
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])
    # Only warns when CUDA_HOME is None; the nvcc query on the next line still runs.
    check_if_cuda_home_none("flash_attn")
    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
    if bare_metal_version < Version("12.3"):
        raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above")
    if bare_metal_version != Version("12.3"):  # nvcc 12.3 gives the best perf currently
        # Fetch the pinned nvcc package and copy its `bin` and `nvvm/bin` directories
        # into third_party/nvidia/backend/bin.
        # NOTE(review): both calls build the same URL; the inner lambda only unpacks the
        # version into three components and does not use them in the URL.
        download_and_copy(
            name="nvcc", src_path=f"bin", dst_path="bin",
            version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version:
            ((lambda version_major, version_minor1, version_minor2:
              f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
             (*version.split('.'))))
        download_and_copy(
            name="nvcc", src_path=f"nvvm/bin", dst_path="bin",
            version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version:
            ((lambda version_major, version_minor1, version_minor2:
              f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
             (*version.split('.'))))
        base_dir = os.path.dirname(__file__)
        ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin")
        nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}")
        # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc
        os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"]
        # Point PYTORCH_NVCC at the copied nvcc binary.
        os.environ["PYTORCH_NVCC"] = nvcc_path_new
        # Make nvcc executable, sometimes after the copy it loses its permissions
        os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC)
    # Code generation target passed to nvcc: compute_90a / sm_90a only.
    cc_flag = []
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_90a,code=sm_90a")
    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    repo_dir = Path(this_dir).parent
    cutlass_dir = repo_dir / "csrc" / "cutlass"
    # One preprocessor definition per enabled DISABLE_*/ENABLE_* switch; the list is
    # passed to both the host compiler and nvcc below.
    feature_args = (
        []
        + (["-DFLASHATTENTION_DISABLE_BACKWARD"] if DISABLE_BACKWARD else [])
        + (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else [])
        + (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else [])
        + (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else [])
        + (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else [])
        + (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else [])
        + (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else [])
        + (["-DFLASHATTENTION_DISABLE_FP16"] if DISABLE_FP16 else [])
        + (["-DFLASHATTENTION_DISABLE_FP8"] if DISABLE_FP8 else [])
        + (["-DFLASHATTENTION_DISABLE_VARLEN"] if DISABLE_VARLEN else [])
        + (["-DFLASHATTENTION_DISABLE_CLUSTER"] if DISABLE_CLUSTER else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM64"] if DISABLE_HDIM64 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM96"] if DISABLE_HDIM96 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM128"] if DISABLE_HDIM128 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM192"] if DISABLE_HDIM192 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else [])
        + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else [])
    )
    # Data types to instantiate: bf16 is always built; fp16 and e4m3 (forward only) are optional.
    DTYPE_FWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else [])
    DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else [])
    # Head dimensions for which backward instantiation files are listed.
    HEAD_DIMENSIONS_BWD = (
        []
        + ([64] if not DISABLE_HDIM64 else [])
        + ([96] if not DISABLE_HDIM96 else [])
        + ([128] if not DISABLE_HDIM128 else [])
        + ([192] if not DISABLE_HDIM192 else [])
        + ([256] if not DISABLE_HDIM256 else [])
    )
    # HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD
    # Forward file names use the literal "all" in place of a head dimension ("hdimall").
    HEAD_DIMENSIONS_FWD = ["all"]
    # Filename suffix variants; the empty string is the variant without that feature.
    SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else [])
    PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else [])
    SOFTCAP = [""] + (["_softcap"] if not DISABLE_SOFTCAP else [])
    PACKGQA = [""] + (["_packgqa"] if not DISABLE_PACKGQA else [])
    # Cartesian product of all variants; note the filename puts {paged} before {split}.
    sources_fwd = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu"
                   for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)]
    sources_bwd = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}_sm90.cu"
                   for hdim, dtype in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD)]
    if DISABLE_BACKWARD:
        sources_bwd = []
    sources = ["flash_api.cpp"] + sources_fwd + sources_bwd
    # The combine source is only added when split is enabled.
    if not DISABLE_SPLIT:
        sources += ["flash_fwd_combine_sm80.cu"]
    nvcc_flags = [
        "-O3",
        "-std=c++17",
        "--ftemplate-backtrace-limit=0",  # To debug template code
        "--use_fast_math",
        # "--keep",
        # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",  # printing out number of registers
        # f"--split-compile={os.getenv('NVCC_THREADS', '4')}",  # split-compile is faster
        "--resource-usage",  # printing out number of registers
        "-lineinfo",
        "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",  # Necessary for the WGMMA shapes that we use
        # "-DCUTLASS_ENABLE_GDC_FOR_SM90",  # For PDL
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
        "-DNDEBUG",  # Important, otherwise performance is severely impacted
    ]
    # Extra flags applied only when get_platform() reports Windows.
    if get_platform() == "win_amd64":
        nvcc_flags.extend(
            [
                "-D_USE_MATH_DEFINES",  # for M_LN2
                "-Xcompiler=/Zc:__cplusplus",  # sets __cplusplus correctly, CUTLASS_CONSTEXPR_IF_CXX17 needed for cutlass::gcd
            ]
        )
    # Header search paths: this directory and the CUTLASS include directory.
    include_dirs = [
        Path(this_dir),
        cutlass_dir / "include",
    ]
    # Register the single extension module built from the sources and flags above.
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_3_cuda",
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + feature_args,
                "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args,
            },
            include_dirs=include_dirs,
        )
    )
def get_package_version():
    """Return the package version read from `__init__.py` in this directory.

    The value assigned to `__version__` is evaluated as a Python literal. When the
    FLASH_ATTN_LOCAL_VERSION environment variable is set and non-empty, it is
    appended as a local version segment (`<public>+<local>`).

    Raises:
        RuntimeError: if `__init__.py` contains no `__version__ = ...` line.
    """
    init_path = Path(this_dir) / "__init__.py"
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with a clear message instead of an AttributeError on None.group().
        raise RuntimeError(f"Could not find a __version__ assignment in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    return str(public_version)
def get_wheel_url():
    """Compose the download URL and filename of the prebuilt wheel for this environment.

    The filename encodes the package version, a CUDA tag, the torch major.minor
    version, the C++11 ABI flag reported by torch, the Python tag and the platform name.

    Returns:
        A `(wheel_url, wheel_filename)` tuple.
    """
    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    built_with_cuda = parse(torch.version.cuda)
    torch_version_raw = parse(torch.__version__)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
    # to save CI time. Minor versions should be compatible.
    if built_with_cuda.major == 11:
        wheel_cuda = parse("11.8")
    else:
        wheel_cuda = parse("12.2")
    py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    plat_name = get_platform()
    pkg_version = get_package_version()
    cuda_tag = f"{wheel_cuda.major}{wheel_cuda.minor}"
    torch_tag = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    abi_flag = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{pkg_version}"
        f"+cu{cuda_tag}torch{torch_tag}cxx11abi{abi_flag}"
        f"-{py_tag}-{py_tag}-{plat_name}.whl"
    )
    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{pkg_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all installs).  We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        """Try to download a matching prebuilt wheel; otherwise build from source.

        When FORCE_BUILD is set the download attempt is skipped entirely. Any
        `urllib.error.URLError` raised while fetching the wheel (which includes
        `HTTPError`, e.g. a 404 for a wheel that does not exist, as well as
        network-level failures) triggers the regular source build.
        """
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            os.makedirs(self.dist_dir, exist_ok=True)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            shutil.move(wheel_filename, wheel_path)
        except urllib.error.URLError:
            # HTTPError is a subclass of URLError, so a missing wheel and an
            # unreachable server are both handled here.
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
# Package metadata and build hooks. `build_ext` is registered only when at least
# one extension module was configured above.
setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks"),
    ),
    py_modules=["flash_attn_interface"],
    description="FlashAttention-3",
    long_description=long_description,
    long_description_content_type="text/markdown",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={
        "bdist_wheel": CachedWheelsCommand,
        **({"build_ext": BuildExtension} if ext_modules else {}),
    },
    python_requires=">=3.8",
    install_requires=["torch", "einops", "packaging", "ninja"],
)
|