123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- # Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- import sys
- import warnings
- import os
- import re
- import shutil
- import ast
- from pathlib import Path
- from packaging.version import parse, Version
- import platform
- from setuptools import setup, find_packages
- import subprocess
- import urllib.request
- import urllib.error
- from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
- import torch
- from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# The long description is the repository-level README one directory above this file.
# Resolve it relative to this file (not the current working directory) so setup works
# regardless of where it is invoked from. If the README is not present (e.g. the parent
# directory is not shipped with this package), fall back to an empty description
# instead of failing the whole install over optional metadata.
_readme_path = os.path.join(this_dir, "..", "README.md")
try:
    with open(_readme_path, "r", encoding="utf-8") as fh:
        long_description = fh.read()
except FileNotFoundError:
    warnings.warn(f"README not found at {_readme_path}; using an empty long_description.")
    long_description = ""

PACKAGE_NAME = "flashattn-hopper"

BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FAHOPPER_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FAHOPPER_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FAHOPPER_FORCE_CXX11_ABI", "FALSE") == "TRUE"
def get_platform():
    """
    Return the platform name as used in wheel filenames.

    On Linux the architecture suffix is taken from ``platform.machine()`` so that
    non-x86 hosts (e.g. aarch64) look up a wheel for their own architecture; if the
    machine type cannot be determined, ``x86_64`` is assumed (the previous
    hard-coded value). macOS and Windows names are unchanged.

    Raises:
        ValueError: if ``sys.platform`` is not Linux, macOS or Windows.
    """
    if sys.platform.startswith("linux"):
        # platform.machine() returns "" when the value cannot be determined.
        machine = platform.machine() or "x86_64"
        return f"linux_{machine}"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))
def get_cuda_bare_metal_version(cuda_dir):
    """
    Run ``<cuda_dir>/bin/nvcc -V`` and extract the CUDA release version.

    Returns a ``(raw_output, version)`` tuple: the full text printed by nvcc and
    the version parsed (via ``packaging.version.parse``) from the token that
    follows the word ``release``, with any trailing ``,...`` removed
    (e.g. ``"12.4,"`` -> ``12.4``).

    Raises:
        subprocess.CalledProcessError: if nvcc exits with a non-zero status.
        ValueError: if the nvcc output contains no ``release`` token.
    """
    nvcc_path = cuda_dir + "/bin/nvcc"
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
    tokens = raw_output.split()
    release_token = tokens[tokens.index("release") + 1]
    version_text, _, _ = release_token.partition(",")
    return raw_output, parse(version_text)
def check_if_cuda_home_none(global_option: str) -> None:
    """
    Emit a warning (never an error) when torch could not locate ``CUDA_HOME``.

    Args:
        global_option: name of the option/feature that needs nvcc; it is
            interpolated into the warning message.
    """
    if CUDA_HOME is None:
        # Only warn: the user may end up installing a prebuilt wheel, in which
        # case nvcc is never invoked.
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
def append_nvcc_threads(nvcc_extra_args, threads=None):
    """
    Return a new list equal to ``nvcc_extra_args`` plus ``["--threads", N]``.

    Args:
        nvcc_extra_args: list of nvcc flags; it is not modified.
        threads: value for nvcc's ``--threads`` option. When omitted, the
            ``NVCC_THREADS`` environment variable is used if set and non-empty,
            otherwise ``4`` (the previous hard-coded value).
    """
    if threads is None:
        # Lets memory-constrained builders lower parallelism without editing setup.py.
        threads = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", str(threads)]
cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
# The submodule path is relative to this file's directory, so run git from there instead of
# from whatever the current working directory is. This is best-effort: git's exit status is not
# checked, and a missing `git` executable only produces a warning rather than aborting setup.
try:
    subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"], cwd=this_dir)
except OSError as err:
    warnings.warn(
        f"Could not run `git submodule update --init ../csrc/cutlass` ({err}); "
        "continuing with whatever CUTLASS sources are already present."
    )

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none("--fahopper")
    cc_flag = []
    # check_if_cuda_home_none only warns (a prebuilt wheel may still be fetched by
    # CachedWheelsCommand), so only probe nvcc when CUDA_HOME was actually found;
    # otherwise get_cuda_bare_metal_version would fail on a None path.
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("12.3"):
            raise RuntimeError("FA Hopper is only supported on CUDA 12.3 and above")
    # Generate code only for the sm_90a target.
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_90a,code=sm_90a")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    repo_dir = Path(this_dir).parent
    cutlass_dir = repo_dir / "csrc" / "cutlass"
    sources = [
        "flash_api.cpp",
        "flash_fwd_hdim64_fp16_sm90.cu",
        "flash_fwd_hdim64_bf16_sm90.cu",
        "flash_fwd_hdim128_fp16_sm90.cu",
        "flash_fwd_hdim128_bf16_sm90.cu",
        "flash_fwd_hdim256_fp16_sm90.cu",
        "flash_fwd_hdim256_bf16_sm90.cu",
        "flash_bwd_hdim64_fp16_sm90.cu",
        "flash_bwd_hdim96_fp16_sm90.cu",
        "flash_bwd_hdim128_fp16_sm90.cu",
        # "flash_bwd_hdim256_fp16_sm90.cu",
        "flash_bwd_hdim64_bf16_sm90.cu",
        "flash_bwd_hdim96_bf16_sm90.cu",
        "flash_bwd_hdim128_bf16_sm90.cu",
        "flash_fwd_hdim64_e4m3_sm90.cu",
        "flash_fwd_hdim128_e4m3_sm90.cu",
        "flash_fwd_hdim256_e4m3_sm90.cu",
        "flash_fwd_hdim64_fp16_gqa2_sm90.cu",
        "flash_fwd_hdim64_fp16_gqa4_sm90.cu",
        "flash_fwd_hdim64_fp16_gqa8_sm90.cu",
        "flash_fwd_hdim64_fp16_gqa16_sm90.cu",
        "flash_fwd_hdim64_fp16_gqa32_sm90.cu",
        "flash_fwd_hdim128_fp16_gqa2_sm90.cu",
        "flash_fwd_hdim128_fp16_gqa4_sm90.cu",
        "flash_fwd_hdim128_fp16_gqa8_sm90.cu",
        "flash_fwd_hdim128_fp16_gqa16_sm90.cu",
        "flash_fwd_hdim128_fp16_gqa32_sm90.cu",
        "flash_fwd_hdim256_fp16_gqa2_sm90.cu",
        "flash_fwd_hdim256_fp16_gqa4_sm90.cu",
        "flash_fwd_hdim256_fp16_gqa8_sm90.cu",
        "flash_fwd_hdim256_fp16_gqa16_sm90.cu",
        "flash_fwd_hdim256_fp16_gqa32_sm90.cu",
        "flash_fwd_hdim64_bf16_gqa2_sm90.cu",
        "flash_fwd_hdim64_bf16_gqa4_sm90.cu",
        "flash_fwd_hdim64_bf16_gqa8_sm90.cu",
        "flash_fwd_hdim64_bf16_gqa16_sm90.cu",
        "flash_fwd_hdim64_bf16_gqa32_sm90.cu",
        "flash_fwd_hdim128_bf16_gqa2_sm90.cu",
        "flash_fwd_hdim128_bf16_gqa4_sm90.cu",
        "flash_fwd_hdim128_bf16_gqa8_sm90.cu",
        "flash_fwd_hdim128_bf16_gqa16_sm90.cu",
        "flash_fwd_hdim128_bf16_gqa32_sm90.cu",
        "flash_fwd_hdim256_bf16_gqa2_sm90.cu",
        "flash_fwd_hdim256_bf16_gqa4_sm90.cu",
        "flash_fwd_hdim256_bf16_gqa8_sm90.cu",
        "flash_fwd_hdim256_bf16_gqa16_sm90.cu",
        "flash_fwd_hdim256_bf16_gqa32_sm90.cu",
        "flash_fwd_hdim64_e4m3_gqa2_sm90.cu",
        "flash_fwd_hdim64_e4m3_gqa4_sm90.cu",
        "flash_fwd_hdim64_e4m3_gqa8_sm90.cu",
        "flash_fwd_hdim64_e4m3_gqa16_sm90.cu",
        "flash_fwd_hdim64_e4m3_gqa32_sm90.cu",
        "flash_fwd_hdim128_e4m3_gqa2_sm90.cu",
        "flash_fwd_hdim128_e4m3_gqa4_sm90.cu",
        "flash_fwd_hdim128_e4m3_gqa8_sm90.cu",
        "flash_fwd_hdim128_e4m3_gqa16_sm90.cu",
        "flash_fwd_hdim128_e4m3_gqa32_sm90.cu",
        "flash_fwd_hdim256_e4m3_gqa2_sm90.cu",
        "flash_fwd_hdim256_e4m3_gqa4_sm90.cu",
        "flash_fwd_hdim256_e4m3_gqa8_sm90.cu",
        "flash_fwd_hdim256_e4m3_gqa16_sm90.cu",
        "flash_fwd_hdim256_e4m3_gqa32_sm90.cu",
    ]
    nvcc_flags = [
        "-O3",
        # "-O0",
        "-std=c++17",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "--ptxas-options=-v",  # printing out number of registers
        "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",  # printing out number of registers
        "-lineinfo",
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
        "-DNDEBUG",  # Important, otherwise performance is severely impacted
    ]
    if get_platform() == "win_amd64":
        nvcc_flags.extend(
            [
                "-D_USE_MATH_DEFINES",  # for M_LN2
                "-Xcompiler=/Zc:__cplusplus",  # sets __cplusplus correctly, CUTLASS_CONSTEXPR_IF_CXX17 needed for cutlass::gcd
            ]
        )
    include_dirs = [
        # Path(this_dir) / "fmha-pipeline",
        # repo_dir / "lib",
        # repo_dir / "include",
        cutlass_dir / "include",
        # cutlass_dir / "examples" / "common",
        # cutlass_dir / "tools" / "util" / "include",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flashattn_hopper_cuda",
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                # "cxx": ["-O0", "-std=c++17"],
                "nvcc": append_nvcc_threads(
                    nvcc_flags + cc_flag
                ),
            },
            include_dirs=include_dirs,
            # Without this we get an error about cuTensorMapEncodeTiled not defined
            libraries=["cuda"]
        )
    )
    # ext_modules.append(
    #     CUDAExtension(
    #         name="flashattn_hopper_cuda_ws",
    #         sources=sources,
    #         extra_compile_args={
    #             "cxx": ["-O3", "-std=c++17"],
    #             "nvcc": append_nvcc_threads(
    #                 nvcc_flags + ["-DEXECMODE=1"] + cc_flag
    #             ),
    #         },
    #         include_dirs=include_dirs,
    #         # Without this we get an error about cuTensorMapEncodeTiled not defined
    #         libraries=["cuda"]
    #     )
    # )
def get_package_version():
    """
    Return the version string for this package.

    The public version is the literal assigned to ``__version__`` in the
    ``__init__.py`` that sits next to this setup.py. If the
    ``FLASHATTN_HOPPER_LOCAL_VERSION`` environment variable is set and
    non-empty, it is appended as a local version label: ``"<public>+<local>"``.

    Raises:
        RuntimeError: if ``__init__.py`` does not contain a ``__version__ = ...`` line.
    """
    init_path = Path(this_dir) / "__init__.py"
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with an actionable message instead of AttributeError on None.group().
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASHATTN_HOPPER_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    return str(public_version)
def get_wheel_url():
    """
    Guess the release URL and filename of a prebuilt wheel for this environment.

    The filename is keyed on: the CUDA version torch was built with (not the
    locally installed toolkit), torch's major.minor version, torch's
    ``_GLIBCXX_USE_CXX11_ABI`` flag (which FAHOPPER_FORCE_CXX11_ABI may have
    overridden earlier in this file), the CPython tag, and ``get_platform()``.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple.
    """
    build_cuda = parse(torch.version.cuda)
    torch_ver = parse(torch.__version__)
    # CI builds only one CUDA minor per major (11.8 / 12.2) to save time; minor
    # versions are expected to be compatible, so snap to those.
    # NOTE(review): the extension build in this file rejects CUDA < 12.3, so a
    # "cu122" tag may never match a published hopper wheel -- confirm the naming.
    cuda_tag = "118" if build_cuda.major == 11 else "122"
    py_tag = "cp{}{}".format(*sys.version_info[:2])
    plat = get_platform()
    version = get_package_version()
    torch_tag = f"{torch_ver.major}.{torch_ver.minor}"
    abi_flag = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    local_label = f"cu{cuda_tag}torch{torch_tag}cxx11abi{abi_flag}"
    wheel_filename = "-".join(
        [PACKAGE_NAME, f"{version}+{local_label}", py_tag, py_tag, f"{plat}.whl"]
    )
    wheel_url = BASE_WHEEL_URL.format(tag_name="v" + version, wheel_name=wheel_filename)
    return wheel_url, wheel_filename
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist_wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available, and if so short-circuit the standard full build pipeline.

    Set FAHOPPER_FORCE_BUILD=TRUE to skip the download attempt and always build from source.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except (urllib.error.HTTPError, urllib.error.URLError):
            # HTTPError: no wheel published under the guessed name.
            # URLError: network unreachable, DNS failure, offline install, ...
            # In every case, fall back to compiling from source instead of aborting.
            print("Precompiled wheel not found. Building from source...")
            super().run()
            return

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        shutil.move(wheel_filename, wheel_path)
# bdist_wheel is always routed through the prebuilt-wheel-aware command; torch's
# BuildExtension is only registered when there are extensions to compile.
_setup_cmdclass = {"bdist_wheel": CachedWheelsCommand}
if ext_modules:
    _setup_cmdclass["build_ext"] = BuildExtension

# Top-level directories that must not be picked up as Python packages.
_EXCLUDED_PACKAGE_DIRS = (
    "build",
    "csrc",
    "include",
    "tests",
    "dist",
    "docs",
    "benchmarks",
)

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(exclude=_EXCLUDED_PACKAGE_DIRS),
    py_modules=["flash_attn_interface"],
    description="FlashAttention-3",
    long_description=long_description,
    long_description_content_type="text/markdown",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass=_setup_cmdclass,
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "einops",
        "packaging",
        "ninja",
    ],
)
|