Prechádzať zdrojové kódy

torch.compile: allow adding custom compile backends via plugins (#1041)

AlpinDale 2 mesiacov pred
rodič
commit
9797d38b24

+ 13 - 0
aphrodite/plugins/__init__.py

@@ -1,3 +1,5 @@
+from typing import Callable, Optional, Union
+
 from loguru import logger
 
 import aphrodite.common.envs as envs
@@ -27,3 +29,14 @@ def load_general_plugins():
             except Exception:
                 logger.exception("Failed to load general plugin: "
                                  f"{plugin.name}")
+
+_torch_compile_backend: Optional[Union[Callable, str]] = None
+
+
+def set_torch_compile_backend(backend: Union[Callable, str]):
+    global _torch_compile_backend
+    _torch_compile_backend = backend
+
+
+def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
+    return _torch_compile_backend

+ 4 - 2
aphrodite/worker/model_runner.py

@@ -1069,13 +1069,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")
 
-        if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo:
+        if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             logger.info("Compiling the model using torch.compile...")
+            from aphrodite.plugins import get_torch_compile_backend
+            backend = get_torch_compile_backend() or "eager"
             start_time = time.time()
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                backend="eager")
+                backend=backend)
             end_time = time.time()
             logger.info(
                 f"Model compiled in {end_time - start_time:.2f} seconds.")