Browse Source

Make nvcc threads configurable via environment variable (#885)

Chirag Jain 1 năm trước cách đây
mục cha
commit
50896ec574

+ 2 - 1
csrc/ft_attention/setup.py

@@ -55,7 +55,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 def append_nvcc_threads(nvcc_extra_args):
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.2"):
-        return nvcc_extra_args + ["--threads", "4"]
+        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+        return nvcc_extra_args + ["--threads", nvcc_threads]
     return nvcc_extra_args
 
 

+ 2 - 1
csrc/fused_dense_lib/setup.py

@@ -19,7 +19,8 @@ def get_cuda_bare_metal_version(cuda_dir):
 def append_nvcc_threads(nvcc_extra_args):
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.2"):
-        return nvcc_extra_args + ["--threads", "4"]
+        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+        return nvcc_extra_args + ["--threads", nvcc_threads]
     return nvcc_extra_args
 
 

+ 2 - 1
csrc/fused_softmax/setup.py

@@ -22,7 +22,8 @@ def get_cuda_bare_metal_version(cuda_dir):
 def append_nvcc_threads(nvcc_extra_args):
     _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
     if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
-        return nvcc_extra_args + ["--threads", "4"]
+        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+        return nvcc_extra_args + ["--threads", nvcc_threads]
     return nvcc_extra_args
 
 

+ 2 - 1
csrc/layer_norm/setup.py

@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 def append_nvcc_threads(nvcc_extra_args):
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.2"):
-        return nvcc_extra_args + ["--threads", "4"]
+        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+        return nvcc_extra_args + ["--threads", nvcc_threads]
     return nvcc_extra_args
 
 

+ 2 - 1
csrc/rotary/setup.py

@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 def append_nvcc_threads(nvcc_extra_args):
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.2"):
-        return nvcc_extra_args + ["--threads", "4"]
+        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+        return nvcc_extra_args + ["--threads", nvcc_threads]
     return nvcc_extra_args
 
 

+ 2 - 1
csrc/xentropy/setup.py

@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 def append_nvcc_threads(nvcc_extra_args):
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.2"):
-        return nvcc_extra_args + ["--threads", "4"]
+        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+        return nvcc_extra_args + ["--threads", nvcc_threads]
     return nvcc_extra_args
 
 

+ 2 - 1
setup.py

@@ -83,7 +83,8 @@ def check_if_cuda_home_none(global_option: str) -> None:
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    return nvcc_extra_args + ["--threads", "4"]
+    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+    return nvcc_extra_args + ["--threads", nvcc_threads]
 
 
 cmdclass = {}