|
@@ -13,9 +13,14 @@ from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
|
|
|
from pytorch_lightning.strategies.ddp import DDPStrategy
|
|
|
from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin
|
|
|
from pytorch_lightning.core.optimizer import LightningOptimizer
|
|
|
-from pytorch_lightning.utilities.types import _PATH
|
|
|
-# from lightning_lite.utilities.types import _PATH
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
|
+try: # pytorch_lightning <= 1.7
|
|
|
+ from pytorch_lightning.utilities.types import _PATH
|
|
|
+except ImportError: # pytorch_lightning >= 1.8
|
|
|
+ try:
|
|
|
+ from lightning_lite.utilities.types import _PATH
|
|
|
+ except ImportError: # pytorch_lightning >= 1.9
|
|
|
+ from lightning_fabric.utilities.types import _PATH
|
|
|
|
|
|
|
|
|
class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
|