Browse Source

fix circular reference with weakref

AlpinDale 8 months ago
parent
commit
16f345c29a

+ 5 - 5
aphrodite/modeling/model_loader/tensorizer.py

@@ -17,7 +17,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.quantization.base_config import QuantizationConfig
 
-tensorizer_load_fail = None
+tensorizer_error_msg = None
 
 try:
     from tensorizer import (DecryptionParams, EncryptionParams,
@@ -26,7 +26,7 @@ try:
     from tensorizer.utils import (convert_bytes, get_mem_usage,
                                   no_init_or_tensor)
 except ImportError as e:
-    tensorizer_load_fail = e
+    tensorizer_error_msg = e
 
 __all__ = [
     'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
@@ -255,12 +255,12 @@ class TensorizerAgent:
 
     def __init__(self, tensorizer_config: TensorizerConfig,
                  quant_config: QuantizationConfig, **extra_kwargs):
-        if tensorizer_load_fail is not None:
+        if tensorizer_error_msg is not None:
             raise ImportError(
                 "Tensorizer is not installed. Please install tensorizer "
                 "to use this feature with "
-                "`pip install aphrodite-engine[tensorizer]`."
-            ) from tensorizer_load_fail
+                "`pip install aphrodite-engine[tensorizer]`. "
+                "Error message: {}".format(tensorizer_error_msg))
 
         self.tensorizer_config = tensorizer_config
         self.tensorizer_args = (

+ 2 - 1
aphrodite/spec_decode/multi_step_worker.py

@@ -1,4 +1,5 @@
 import copy
+import weakref
 from typing import List, Tuple
 
 import torch
@@ -32,7 +33,7 @@ class MultiStepWorker(Worker):
         super().init_device()
 
         self._proposer = Top1Proposer(
-            self,
+            weakref.proxy(self),
             self.device,
             self.vocab_size,
             max_proposal_len=self.max_model_len,

+ 2 - 1
aphrodite/spec_decode/ngram_worker.py

@@ -1,3 +1,4 @@
+import weakref
 from typing import List, Optional, Tuple
 
 import torch
@@ -37,7 +38,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
 
         # Current only support Top1Proposer
         self._proposer = Top1Proposer(
-            self,
+            weakref.proxy(self),
             device=self.device,
             vocab_size=self.vocab_size,
         )