|
@@ -8,30 +8,85 @@ import warnings
|
|
|
from packaging.version import parse, Version
|
|
|
import setuptools
|
|
|
import torch
|
|
|
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
|
|
+from torch.utils.cpp_extension import (
|
|
|
+ BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME)
|
|
|
|
|
|
ROOT_DIR = os.path.dirname(__file__)
|
|
|
|
|
|
MAIN_CUDA_VERSION = "11.8"
|
|
|
|
|
|
# Supported NVIDIA GPU architectures.
|
|
|
-SUPPORTED_ARCHS = {
|
|
|
+NVIDIA_SUPPORTED_ARCHS = {
|
|
|
"6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"
|
|
|
}
|
|
|
+# Supported AMD GPU architectures; the ROCm build compares the value returned
+# by get_amdgpu_offload_arch() against this set.
+ROCM_SUPPORTED_ARCHS = {
+    "gfx906", "gfx908", "gfx90a", "gfx1030", "gfx1100",
+}
|
|
|
+
|
|
|
+def _is_hip() -> bool:
+    """Return True when ``torch.version.hip`` is set (not None)."""
+    hip_version = torch.version.hip
+    return hip_version is not None
|
|
|
+
|
|
|
+def _is_cuda() -> bool:
+    """Return True when ``torch.version.cuda`` is set (not None)."""
+    cuda_version = torch.version.cuda
+    return cuda_version is not None
|
|
|
+
|
|
|
|
|
|
# Compiler flags.
|
|
|
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
|
|
# TODO: Should we use -O3?
|
|
|
NVCC_FLAGS = ["-O2", "-std=c++17"]
|
|
|
|
|
|
-ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
|
|
|
-CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|
|
-NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|
|
|
|
|
-if CUDA_HOME is None:
|
|
|
+if _is_hip():
+    # A ROCm build cannot proceed when ROCM_HOME is unset.
+    if ROCM_HOME is None:
+        raise RuntimeError(
+            "Cannot find ROCM_HOME. ROCm must be available to build the "
+            "package.")
+    # Define USE_ROCM for every source compiled with the nvcc flag list.
+    NVCC_FLAGS.append("-DUSE_ROCM")
+
+if _is_cuda() and CUDA_HOME is None:
     raise RuntimeError(
         "Cannot find CUDA_HOME. CUDA must be available to build the package.")

+# Mirror torch's _GLIBCXX_USE_CXX11_ABI setting in both flag lists.
+ABI = int(bool(torch._C._GLIBCXX_USE_CXX11_ABI))
+CXX_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}")
+NVCC_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}")
|
|
|
+
|
|
|
+def get_amdgpu_offload_arch():
+    """Return the stripped, UTF-8 decoded stdout of ``amdgpu-offload-arch``.
+
+    The tool under ``ROCM_HOME`` is used when that file exists; otherwise the
+    historical default ``/opt/rocm/llvm/bin/amdgpu-offload-arch`` is run.
+    The result is presumably an architecture name such as ``gfx90a`` (the
+    caller checks it against ``ROCM_SUPPORTED_ARCHS``).
+
+    Raises:
+        RuntimeError: If the tool cannot be found or exits with a non-zero
+            status.
+    """
+    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
+    if ROCM_HOME is not None:
+        # Prefer the copy that belongs to the ROCm installation in use.
+        candidate = os.path.join(ROCM_HOME, "llvm", "bin",
+                                 "amdgpu-offload-arch")
+        if os.path.isfile(candidate):
+            command = candidate
+
+    try:
+        output = subprocess.check_output([command])
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Error: {e}") from e
+    except FileNotFoundError as e:
+        raise RuntimeError(f"The command {command} was not found.") from e
+
+    return output.decode('utf-8').strip()
|
|
|
+
|
|
|
+
|
|
|
+def get_hipcc_rocm_version():
+    """Return the version string that follows ``HIP version:`` in hipcc output.
+
+    Runs ``hipcc --version`` with stderr merged into stdout. Returns None,
+    after printing a diagnostic, when hipcc exits with a non-zero status or
+    when no ``HIP version: ...`` entry is present in its output.
+    """
+    proc = subprocess.run(['hipcc', '--version'],
+                          stdout=subprocess.PIPE,
+                          stderr=subprocess.STDOUT,
+                          text=True)
+
+    # Guard: the command itself failed.
+    if proc.returncode != 0:
+        print("Error running 'hipcc --version'")
+        return None
+
+    # Guard: the command ran but did not report a HIP version.
+    found = re.search(r'HIP version: (\S+)', proc.stdout)
+    if found is None:
+        print("Could not find HIP version in the output")
+        return None
+
+    return found.group(1)
|
|
|
|
|
|
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
|
|
"""Get the CUDA version from nvcc.
|
|
@@ -63,20 +118,22 @@ def get_torch_arch_list() -> Set[str]:
|
|
|
return set()
|
|
|
|
|
|
# Filter out the invalid architectures and print a warning.
|
|
|
- valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
|
|
|
+ valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
|
|
|
+ {s + "+PTX"
|
|
|
+ for s in NVIDIA_SUPPORTED_ARCHS})
|
|
|
arch_list = torch_arch_list.intersection(valid_archs)
|
|
|
# If none of the specified architectures are valid, raise an error.
|
|
|
if not arch_list:
|
|
|
raise RuntimeError(
|
|
|
- "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
|
|
- f"variable ({env_arch_list}) is supported. "
|
|
|
+ "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` "
|
|
|
+ f"env variable ({env_arch_list}) is supported. "
|
|
|
f"Supported CUDA architectures are: {valid_archs}.")
|
|
|
invalid_arch_list = torch_arch_list - valid_archs
|
|
|
if invalid_arch_list:
|
|
|
warnings.warn(
|
|
|
- f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
|
|
+ f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
|
|
|
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
|
|
- f"({env_arch_list}). Supported CUDA architectures are: "
|
|
|
+ f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
|
|
|
f"{valid_archs}.",
|
|
|
stacklevel=2)
|
|
|
return arch_list
|
|
@@ -84,7 +141,7 @@ def get_torch_arch_list() -> Set[str]:
|
|
|
|
|
|
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
|
|
compute_capabilities = get_torch_arch_list()
|
|
|
-if not compute_capabilities:
|
|
|
+if _is_cuda() and not compute_capabilities:
|
|
|
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
|
|
# GPUs on the current machine.
|
|
|
device_count = torch.cuda.device_count()
|
|
@@ -95,72 +152,87 @@ if not compute_capabilities:
|
|
|
"GPUs with compute capability below 6.0 are not supported.")
|
|
|
compute_capabilities.add(f"{major}.{minor}")
|
|
|
|
|
|
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
|
|
-if not compute_capabilities:
|
|
|
- # If no GPU is specified nor available, add all supported architectures
|
|
|
- # based on the NVCC CUDA version.
|
|
|
- compute_capabilities = SUPPORTED_ARCHS.copy()
|
|
|
- if nvcc_cuda_version < Version("11.1"):
|
|
|
- compute_capabilities.remove("8.6")
|
|
|
+if _is_cuda():
+    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+    if not compute_capabilities:
+        # If no GPU is specified nor available, add all supported architectures
+        # based on the NVCC CUDA version.
+        compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
+        if nvcc_cuda_version < Version("11.1"):
+            compute_capabilities.remove("8.6")
+        if nvcc_cuda_version < Version("11.8"):
+            compute_capabilities.remove("8.9")
+            compute_capabilities.remove("9.0")
+    # Validate the NVCC CUDA version.
+    if nvcc_cuda_version < Version("11.0"):
+        raise RuntimeError(
+            "CUDA 11.0 or higher is required to build the package.")
+    if (nvcc_cuda_version < Version("11.1")
+            and any(cc.startswith("8.6") for cc in compute_capabilities)):
+        raise RuntimeError(
+            "CUDA 11.1 or higher is required for compute capability 8.6.")
     if nvcc_cuda_version < Version("11.8"):
-        compute_capabilities.remove("8.9")
-        compute_capabilities.remove("9.0")
-
-# Validate the NVCC CUDA version.
-if nvcc_cuda_version < Version("11.0"):
-    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if (nvcc_cuda_version < Version("11.1")
-        and any(cc.startswith("8.6") for cc in compute_capabilities)):
-    raise RuntimeError(
-        "CUDA 11.1 or higher is required for compute capability 8.6.")
-if nvcc_cuda_version < Version("11.8"):
-    if any(cc.startswith("8.9") for cc in compute_capabilities):
-        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-        # However, GPUs with compute capability 8.9 can also run the code generated by
-        # the previous versions of CUDA 11 and targeting compute capability 8.0.
-        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-        # instead of 8.9.
-        warnings.warn(
-            "CUDA 11.8 or higher is required for compute capability 8.9. "
-            "Targeting compute capability 8.0 instead.",
-            stacklevel=2)
-        compute_capabilities = set(cc for cc in compute_capabilities
-                                   if not cc.startswith("8.9"))
-        compute_capabilities.add("8.0+PTX")
-    if any(cc.startswith("9.0") for cc in compute_capabilities):
+        if any(cc.startswith("8.9") for cc in compute_capabilities):
+            # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+            # However, GPUs with compute capability 8.9 can also run the code generated by
+            # the previous versions of CUDA 11 and targeting compute capability 8.0.
+            # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+            # instead of 8.9.
+            warnings.warn(
+                "CUDA 11.8 or higher is required for compute capability 8.9. "
+                "Targeting compute capability 8.0 instead.",
+                stacklevel=2)
+            compute_capabilities = set(cc for cc in compute_capabilities
+                                       if not cc.startswith("8.9"))
+            compute_capabilities.add("8.0+PTX")
+        if any(cc.startswith("9.0") for cc in compute_capabilities):
+            raise RuntimeError(
+                "CUDA 11.8 or higher is required for compute capability 9.0.")
+
+    # Add target compute capabilities to NVCC flags.
+    for capability in compute_capabilities:
+        # "8.6" or "8.6+PTX" -> "86": major digit plus minor digit.
+        num = capability[0] + capability[2]
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        if capability.endswith("+PTX"):
+            NVCC_FLAGS += [
+                "-gencode", f"arch=compute_{num},code=compute_{num}"
+            ]
+
+    # Use NVCC threads to parallelize the build.
+    if nvcc_cuda_version >= Version("11.2"):
+        # os.cpu_count() may return None; fall back to a single thread then.
+        num_threads = min(os.cpu_count() or 1, 8)
+        NVCC_FLAGS += ["--threads", str(num_threads)]
+
+elif _is_hip():
+    # ROCm builds only proceed for an architecture listed in
+    # ROCM_SUPPORTED_ARCHS.
+    amd_arch = get_amdgpu_offload_arch()
+    if amd_arch not in ROCM_SUPPORTED_ARCHS:
         raise RuntimeError(
-            "CUDA 11.8 or higher is required for compute capability 9.0.")
+            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}. "
+            f"amdgpu_arch_found: {amd_arch}")
|
|
|
|
|
|
-# Add target compute capabilities to NVCC flags.
|
|
|
-for capability in compute_capabilities:
|
|
|
- num = capability[0] + capability[2]
|
|
|
- NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
|
|
- if capability.endswith("+PTX"):
|
|
|
- NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
|
|
+ext_modules = []
|
|
|
|
|
|
-# Use NVCC threads to parallelize the build.
|
|
|
-if nvcc_cuda_version >= Version("11.2"):
|
|
|
- num_threads = min(os.cpu_count(), 8)
|
|
|
- NVCC_FLAGS += ["--threads", str(num_threads)]
|
|
|
+# Sources passed to the aphrodite._C extension.
+aphrodite_extension_sources = [
+    "kernels/cache_kernels.cu",
+    "kernels/attention/attention_kernels.cu",
+    "kernels/pos_encoding_kernels.cu",
+    "kernels/activation_kernels.cu",
+    "kernels/layernorm_kernels.cu",
+    "kernels/quantization/squeezellm/quant_cuda_kernel.cu",
+    "kernels/quantization/gptq/exllama_ext.cpp",
+    "kernels/quantization/gptq/q_matrix.cu",
+    "kernels/quantization/gptq/q_gemm.cu",
+    "kernels/quantization/gptq/old_matmul_kernel.cu",
+    "kernels/cuda_utils_kernels.cu",
+    "kernels/pybind.cpp",
+]
+
+# The AWQ GEMM kernels are only added to the source list for CUDA builds.
+if _is_cuda():
+    aphrodite_extension_sources += [
+        "kernels/quantization/awq/gemm_kernels.cu"
+    ]
|
|
|
|
|
|
-ext_modules = []
|
|
|
aphrodite_extension = CUDAExtension(
|
|
|
name="aphrodite._C",
|
|
|
- sources=[
|
|
|
- "kernels/cache_kernels.cu",
|
|
|
- "kernels/attention/attention_kernels.cu",
|
|
|
- "kernels/pos_encoding_kernels.cu",
|
|
|
- "kernels/activation_kernels.cu",
|
|
|
- "kernels/layernorm_kernels.cu",
|
|
|
- "kernels/quantization/awq/gemm_kernels.cu",
|
|
|
- "kernels/quantization/squeezellm/quant_cuda_kernel.cu",
|
|
|
- "kernels/quantization/gptq/exllama_ext.cpp",
|
|
|
- "kernels/quantization/gptq/q_matrix.cu",
|
|
|
- "kernels/quantization/gptq/q_gemm.cu",
|
|
|
- "kernels/quantization/gptq/old_matmul_kernel.cu",
|
|
|
- "kernels/cuda_utils_kernels.cu",
|
|
|
- "kernels/pybind.cpp",
|
|
|
- ],
|
|
|
+ sources=aphrodite_extension_sources,
|
|
|
extra_compile_args={
|
|
|
"cxx": CXX_FLAGS,
|
|
|
"nvcc": NVCC_FLAGS,
|
|
@@ -188,21 +260,29 @@ def find_version(filepath: str) -> str:
|
|
|
|
|
|
def get_aphrodite_version() -> str:
|
|
|
version = find_version(get_path("aphrodite-engine", "__init__.py"))
|
|
|
- cuda_version = str(nvcc_cuda_version)
|
|
|
|
|
|
- # Split the version into numerical and suffix parts
|
|
|
- version_parts = version.split('-')
|
|
|
- version_num = version_parts[0]
|
|
|
- version_suffix = version_parts[1] if len(version_parts) > 1 else ''
|
|
|
-
|
|
|
- if cuda_version != MAIN_CUDA_VERSION:
|
|
|
- cuda_version_str = cuda_version.replace(".", "")[:3]
|
|
|
- version_num += f"+cu{cuda_version_str}"
|
|
|
-
|
|
|
- # Reassemble the version string with the suffix, if any
|
|
|
- version = version_num + ('-' + version_suffix if version_suffix else '')
|
|
|
-
|
|
|
- return version
|
|
|
+ if _is_hip():
|
|
|
+ # get the HIP version
|
|
|
+
|
|
|
+ hipcc_version = get_hipcc_rocm_version()
|
|
|
+ if hipcc_version != MAIN_CUDA_VERSION:
|
|
|
+ rocm_version_str = hipcc_version.replace(".", "")[:3]
|
|
|
+ version += f"+rocm{rocm_version_str}"
|
|
|
+ else:
|
|
|
+ cuda_version = str(nvcc_cuda_version)
|
|
|
+ # Split the version into numerical and suffix parts
|
|
|
+ version_parts = version.split('-')
|
|
|
+ version_num = version_parts[0]
|
|
|
+ version_suffix = version_parts[1] if len(version_parts) > 1 else ''
|
|
|
+
|
|
|
+ if cuda_version != MAIN_CUDA_VERSION:
|
|
|
+ cuda_version_str = cuda_version.replace(".", "")[:3]
|
|
|
+ version_num += f"+cu{cuda_version_str}"
|
|
|
+
|
|
|
+ # Reassemble the version string with the suffix, if any
|
|
|
+ version = version_num + ('-' + version_suffix if version_suffix else '')
|
|
|
+
|
|
|
+ return version
|
|
|
|
|
|
|
|
|
def read_readme() -> str:
|
|
@@ -216,8 +296,12 @@ def read_readme() -> str:
|
|
|
|
|
|
def get_requirements() -> List[str]:
|
|
|
"""Get Python package dependencies from requirements.txt."""
|
|
|
- with open(get_path("requirements.txt")) as f:
|
|
|
- requirements = f.read().strip().split("\n")
|
|
|
+ if _is_hip():
|
|
|
+ with open(get_path("requirements-rocm.txt")) as f:
|
|
|
+ requirements = f.read().strip().split("\n")
|
|
|
+ else:
|
|
|
+ with open(get_path("requirements.txt")) as f:
|
|
|
+ requirements = f.read().strip().split("\n")
|
|
|
return requirements
|
|
|
|
|
|
|
|
@@ -251,4 +335,4 @@ setuptools.setup(
|
|
|
ext_modules=ext_modules,
|
|
|
cmdclass={"build_ext": BuildExtension},
|
|
|
package_data={"aphrodite-engine": ["py.typed"]},
|
|
|
-)
|
|
|
+)
|