Procházet zdrojové kódy

move func tracing to utils

AlpinDale před 8 měsíci
rodič
revize
2e0b115ce1
2 změnil soubory, kde provedl 19 přidání a 14 odebrání
  1. 17 0
      aphrodite/common/utils.py
  2. 2 14
      aphrodite/task_handler/worker_base.py

+ 17 - 0
aphrodite/common/utils.py

@@ -1,10 +1,13 @@
 import asyncio
+import datetime
 import enum
 import gc
 import glob
 import os
 import socket
 import subprocess
+import tempfile
+import threading
 import uuid
 import warnings
 from collections import defaultdict
@@ -19,6 +22,8 @@ import torch
 from loguru import logger
 from packaging.version import Version, parse
 
+from aphrodite.common.logger import enable_trace_function_call
+
 T = TypeVar("T")
 
 STR_DTYPE_TO_TORCH_DTYPE = {
@@ -604,3 +609,15 @@ def find_nccl_library():
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
         logger.info(f"Found nccl from library {so_file}")
     return so_file
+
+
+def enable_trace_function_call_for_thread() -> None:
+    if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
+        tmp_dir = tempfile.gettempdir()
+        filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
+                    f"_thread_{threading.get_ident()}_"
+                    f"at_{datetime.datetime.now()}.log").replace(" ", "_")
+        log_path = os.path.join(tmp_dir, "aphrodite",
+                                get_aphrodite_instance_id(), filename)
+        os.makedirs(os.path.dirname(log_path), exist_ok=True)
+        enable_trace_function_call(log_path)

+ 2 - 14
aphrodite/task_handler/worker_base.py

@@ -1,16 +1,12 @@
-import datetime
 import importlib
 import os
-import tempfile
-import threading
 from abc import ABC, abstractmethod
 from typing import Dict, List, Set, Tuple
 
 from loguru import logger
 
-from aphrodite.common.logger import enable_trace_function_call
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (get_aphrodite_instance_id,
+from aphrodite.common.utils import (enable_trace_function_call_for_thread,
                                     update_environment_variables)
 from aphrodite.lora.request import LoRARequest
 
@@ -129,15 +125,7 @@ class WorkerWrapperBase:
        function tracing if required.
         Arguments are passed to the worker class constructor.
         """
-        if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
-            tmp_dir = tempfile.gettempdir()
-            filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
-                        f"_thread_{threading.get_ident()}_"
-                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
-            log_path = os.path.join(tmp_dir, "aphrodite",
-                                    get_aphrodite_instance_id(), filename)
-            os.makedirs(os.path.dirname(log_path), exist_ok=True)
-            enable_trace_function_call(log_path)
+        enable_trace_function_call_for_thread()
 
         mod = importlib.import_module(self.worker_module_name)
         worker_class = getattr(mod, self.worker_class_name)