Просмотр исходного кода

fix: set CPU Affinity (#187)

* Set cpu affinity for each ray process.

* Try after all workers are ready.

* Added check for hyper-threading and set worker affinities based on that if true.

* Cleanup

* Fix lint

Also change is_hyperthreading to ht_scale

* Fix unusual modulus assignment?
KaraKaraWitch 1 год назад
Родитель
Сommit
9a0b5a197d
1 измененных файлов с 23 добавлено и 0 удалено
  1. 23 0
      aphrodite/engine/aphrodite_engine.py

+ 23 - 0
aphrodite/engine/aphrodite_engine.py

@@ -3,6 +3,8 @@ import time
 from functools import partial
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 
+import psutil
+
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig)
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
@@ -197,6 +199,27 @@ class AphroditeEngine:
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers,
         )
+        
+        # HACK
+        # After running ray.init(), ray processes affinity is set to (0,1).
+        # (or whatever the CPU scheduler fancies)
+        # We however want the actual workers that are being used,
+        # so we call here since calling after ray.init() and everything else.
+        # We reassign each ray process by taking the
+        # modulus of the number of cpu_cores available.
+        # Issue: https://github.com/PygmalionAI/aphrodite-engine/issues/115
+        # The solution is similar to the taskset solution linked above.
+        current_process = psutil.Process()
+        ray_threads = 0
+        logical_cores = psutil.cpu_count(logical=True)
+        physical_cores = psutil.cpu_count(logical=False)
+        ht_scale = physical_cores / logical_cores
+        for process in current_process.children(recursive=True):
+            # process.pid
+            if "ray::" in process.name():
+                process.cpu_affinity([ray_threads])
+                ray_threads += int(1 * ht_scale) if ht_scale > 1.0 else 1
+                ray_threads = ray_threads % logical_cores
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)