"""Build script for the ``fused_dense_lib`` CUDA extension.

Compiles ``fused_dense.cpp`` and ``fused_dense_cuda.cu`` using PyTorch's
``CUDAExtension`` / ``BuildExtension`` helpers, with ``-O3`` for both the host
compiler and nvcc.
"""

import os
import subprocess

from packaging.version import parse, Version
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    """Run ``nvcc -V`` from ``cuda_dir`` and parse its release version.

    Args:
        cuda_dir: Root directory of the CUDA toolkit; ``bin/nvcc`` beneath it
            is executed.

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple. ``raw_output`` is the
        full text printed by ``nvcc -V``. ``bare_metal_version`` is the
        ``packaging`` version parsed from the whitespace-separated token that
        follows ``release``, with anything after a comma stripped.

    Raises:
        RuntimeError: If ``cuda_dir`` is ``None`` (no CUDA toolkit located).
        subprocess.CalledProcessError: If nvcc exits with a non-zero status.
        OSError: If the nvcc executable cannot be launched.
        ValueError: If the nvcc output contains no ``release`` token.
    """
    if cuda_dir is None:
        # CUDA_HOME from torch.utils.cpp_extension can be None when no toolkit
        # is found; fail with an actionable message instead of a TypeError
        # raised while building the nvcc path.
        raise RuntimeError(
            "CUDA_HOME is not set and no CUDA toolkit was found; set CUDA_HOME "
            "to the CUDA installation directory to build fused_dense_lib."
        )
    nvcc = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc, "-V"], universal_newlines=True)
    output = raw_output.split()
    # The version token follows the word "release" and may carry a trailing
    # comma (e.g. "11.8,"), which is stripped before parsing.
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version


def append_nvcc_threads(nvcc_extra_args):
    """Return a copy of ``nvcc_extra_args`` with ``--threads N`` appended.

    The flag is only added when the toolkit at ``CUDA_HOME`` reports version
    11.2 or newer. Otherwise the original list is returned unchanged.

    ``N`` comes from the ``NVCC_THREADS`` environment variable. If that
    variable is unset or empty, ``N`` defaults to ``"4"``.

    Args:
        nvcc_extra_args: List of extra nvcc arguments. It is not mutated.

    Returns:
        The argument list, possibly extended with ``["--threads", N]``.
    """
    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
    if bare_metal_version >= Version("11.2"):
        # `or` (rather than a getenv default) also treats NVCC_THREADS="" as unset.
        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
        return nvcc_extra_args + ["--threads", nvcc_threads]
    return nvcc_extra_args


setup(
    name='fused_dense_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_dense_lib',
            sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
            extra_compile_args={
                'cxx': ['-O3',],
                'nvcc': append_nvcc_threads(['-O3'])
            }
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    })