Ver código fonte

[Training] Fix lightning _PATH import

Tri Dao 2 anos atrás
pai
commit
009a3e71ec
2 arquivos alterados com 14 adições e 4 exclusões
  1. 7 2
      training/src/utils/ddp_zero1.py
  2. 7 2
      training/src/utils/ddp_zero2.py

+ 7 - 2
training/src/utils/ddp_zero1.py

@@ -9,8 +9,13 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
 
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.utilities.types import _PATH
-# from lightning_lite.utilities.types import _PATH
+try:  # pytorch_lightning <= 1.7
+    from pytorch_lightning.utilities.types import _PATH
+except ImportError:  # pytorch_lightning >= 1.8
+    try:
+        from lightning_lite.utilities.types import _PATH
+    except ImportError:  # pytorch_lightning >= 1.9
+        from lightning_fabric.utilities.types import _PATH
 
 
 # Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get

+ 7 - 2
training/src/utils/ddp_zero2.py

@@ -13,9 +13,14 @@ from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.utilities.types import _PATH
-# from lightning_lite.utilities.types import _PATH
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+try:  # pytorch_lightning <= 1.7
+    from pytorch_lightning.utilities.types import _PATH
+except ImportError:  # pytorch_lightning >= 1.8
+    try:
+        from lightning_lite.utilities.types import _PATH
+    except ImportError:  # pytorch_lightning >= 1.9
+        from lightning_fabric.utilities.types import _PATH
 
 
 class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):