
fix: torch.compile dynamo fix (#1122)

AlpinDale 1 month ago
parent
commit
ede17d5039
1 changed file with 8 additions and 4 deletions

+ 8 - 4
aphrodite/quantization/utils/layer_utils.py

@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
                       new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
 
     old = getattr(mod, name)
-    if old.dtype == new.dtype  and \
+    if type(old) is type(new) and old.dtype == new.dtype and \
         old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
         # If we can just update in-place to avoid re-registering
         #   can be faster if the underlying storage is the same
         update_tensor_inplace(old, new)
     else:
-        # Fallback re-register parameter
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
         if not isinstance(new, torch.nn.Parameter):
-            new = torch.nn.Parameter(new)
-        mod.register_parameter(name, torch.nn.Parameter(new))
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name,
+                               torch.nn.Parameter(new, requires_grad=False))
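
A minimal sketch (not from the repository) of why the added `type(old) is type(new)` check matters: a `torch.nn.Parameter` subclass, as quantization schemes often attach to a module, can match the old value in dtype and storage size, so the previous condition would take the in-place path and leave the subclass registered, which `torch.compile`'s dynamo tracer may not handle. The `ScaledParameter` class below is hypothetical and stands in for such a subclass.

import torch

class ScaledParameter(torch.nn.Parameter):
    # Hypothetical Parameter subclass standing in for a quantized weight.
    pass

mod = torch.nn.Linear(4, 4, bias=False)
mod.register_parameter("weight", ScaledParameter(mod.weight.data))

old = mod.weight
new = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)

# dtype and storage size match, so the pre-fix condition would update
# in place and keep the ScaledParameter subclass on the module.
print(old.dtype == new.dtype and
      old.untyped_storage().nbytes() == new.untyped_storage().nbytes())  # True

# The fixed condition also compares types, so this case falls through to
# the branch that re-registers a plain nn.Parameter for torch.compile.
print(type(old) is type(new))  # False -> re-register path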