setup.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
import os
import subprocess

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
  6. def get_cuda_bare_metal_version(cuda_dir):
  7. raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
  8. output = raw_output.split()
  9. release_idx = output.index("release") + 1
  10. release = output[release_idx].split(".")
  11. bare_metal_major = release[0]
  12. bare_metal_minor = release[1][0]
  13. return raw_output, bare_metal_major, bare_metal_minor
  14. def append_nvcc_threads(nvcc_extra_args):
  15. _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
  16. if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
  17. return nvcc_extra_args + ["--threads", "4"]
  18. return nvcc_extra_args
  19. setup(
  20. name='fused_dense_lib',
  21. ext_modules=[
  22. CUDAExtension(
  23. name='fused_dense_lib',
  24. sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
  25. extra_compile_args={
  26. 'cxx': ['-O3',],
  27. 'nvcc': append_nvcc_threads(['-O3'])
  28. }
  29. )
  30. ],
  31. cmdclass={
  32. 'build_ext': BuildExtension
  33. })