123456789101112131415161718192021222324252627282930313233343536373839404142 |
- import os
- import subprocess
- from packaging.version import parse, Version
- import torch
- from setuptools import setup
- from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
def get_cuda_bare_metal_version(cuda_dir):
    """Query the CUDA toolkit under *cuda_dir* for its nvcc release version.

    Runs ``<cuda_dir>/bin/nvcc -V`` and parses the token that follows the word
    ``release`` in its output (e.g. ``"release 11.8,"`` yields ``11.8``).

    Args:
        cuda_dir: Root directory of the CUDA toolkit, e.g.
            ``torch.utils.cpp_extension.CUDA_HOME``. May be ``None`` when no
            toolkit was located, in which case a clear error is raised.

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple: the full text printed by
        ``nvcc -V`` and the version object produced by ``packaging``'s
        ``parse``.

    Raises:
        RuntimeError: If *cuda_dir* is ``None``, or if the ``nvcc -V`` output
            contains no version token after ``release``.
        OSError / subprocess.CalledProcessError: If nvcc cannot be executed
            or exits with a non-zero status.
    """
    if cuda_dir is None:
        # Without this guard, `None + "/bin/nvcc"` fails with an opaque TypeError.
        raise RuntimeError(
            "CUDA toolkit not found (CUDA_HOME is not set); "
            "cannot locate nvcc to determine the CUDA version."
        )
    nvcc = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc, "-V"], universal_newlines=True)
    tokens = raw_output.split()
    try:
        version_token = tokens[tokens.index("release") + 1]
    except (ValueError, IndexError):
        # ValueError: no "release" token at all; IndexError: it is the last token.
        raise RuntimeError(
            f"Could not find a CUDA release number in the output of "
            f"'{nvcc} -V':\n{raw_output}"
        ) from None
    # The token looks like "11.8," -- drop the trailing comma before parsing.
    bare_metal_version = parse(version_token.split(",")[0])
    return raw_output, bare_metal_version
def append_nvcc_threads(nvcc_extra_args):
    """Return *nvcc_extra_args*, plus ``--threads N`` on CUDA 11.2 or newer.

    The toolkit at ``CUDA_HOME`` is queried for its version. Below 11.2 the
    argument list is returned unchanged (the same object). Otherwise a new
    list is returned with ``["--threads", N]`` appended. ``N`` comes from the
    ``NVCC_THREADS`` environment variable and falls back to ``"4"`` when that
    variable is unset or empty. The input list is never mutated.
    """
    cuda_version = get_cuda_bare_metal_version(CUDA_HOME)[1]
    if cuda_version < Version("11.2"):
        # Version-gated, presumably because older nvcc lacks --threads.
        return nvcc_extra_args
    thread_count = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", thread_count]
# nvcc gets -O3, plus "--threads N" when the CUDA toolkit is 11.2 or newer.
_nvcc_flags = append_nvcc_threads(['-O3'])

# The single extension module, compiled from fused_dense.cpp and
# fused_dense_cuda.cu; the host compiler also gets -O3.
_fused_dense_extension = CUDAExtension(
    name='fused_dense_lib',
    sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
    extra_compile_args={'cxx': ['-O3'], 'nvcc': _nvcc_flags},
)

# torch's BuildExtension replaces setuptools' default build_ext command.
setup(
    name='fused_dense_lib',
    ext_modules=[_fused_dense_extension],
    cmdclass={'build_ext': BuildExtension},
)
|