|
@@ -22,7 +22,8 @@ def get_cuda_bare_metal_version(cuda_dir):
|
|
def append_nvcc_threads(nvcc_extra_args):
|
|
def append_nvcc_threads(nvcc_extra_args):
|
|
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
|
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
|
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
|
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
|
- return nvcc_extra_args + ["--threads", "4"]
|
|
|
|
|
|
+ nvcc_threads = os.getenv("NVCC_THREADS") or "4"
|
|
|
|
+ return nvcc_extra_args + ["--threads", nvcc_threads]
|
|
return nvcc_extra_args
|
|
return nvcc_extra_args
|
|
|
|
|
|
|
|
|