setup.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import os
  2. import subprocess
  3. from packaging.version import parse, Version
  4. import torch
  5. from setuptools import setup
  6. from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
  7. def get_cuda_bare_metal_version(cuda_dir):
  8. raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
  9. output = raw_output.split()
  10. release_idx = output.index("release") + 1
  11. bare_metal_version = parse(output[release_idx].split(",")[0])
  12. return raw_output, bare_metal_version
  13. def append_nvcc_threads(nvcc_extra_args):
  14. _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
  15. if bare_metal_version >= Version("11.2"):
  16. nvcc_threads = os.getenv("NVCC_THREADS") or "4"
  17. return nvcc_extra_args + ["--threads", nvcc_threads]
  18. return nvcc_extra_args
  19. setup(
  20. name='fused_dense_lib',
  21. ext_modules=[
  22. CUDAExtension(
  23. name='fused_dense_lib',
  24. sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
  25. extra_compile_args={
  26. 'cxx': ['-O3',],
  27. 'nvcc': append_nvcc_threads(['-O3'])
  28. }
  29. )
  30. ],
  31. cmdclass={
  32. 'build_ext': BuildExtension
  33. })