1
0

run.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from typing import Callable
  2. import dotenv
  3. import hydra
  4. from omegaconf import OmegaConf, DictConfig
  5. # load environment variables from `.env` file if it exists
  6. # recursively searches for `.env` in all folders starting from work dir
  7. dotenv.load_dotenv(override=True)
  8. OmegaConf.register_new_resolver('eval', eval)
  9. OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
  10. # Delay the evaluation until we have the datamodule
  11. # So we want the resolver to yield the same string.
  12. OmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}')
  13. # Turn on TensorFloat32
  14. import torch.backends
  15. torch.backends.cuda.matmul.allow_tf32 = True
  16. torch.backends.cudnn.allow_tf32 = True
  17. def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig:
  18. """Only keep keys where fn(key) is True. Support nested DictConfig.
  19. """
  20. # Using d.items_ex(resolve=False) instead of d.items() since we want to keep the
  21. # ${datamodule:foo} unresolved for now.
  22. return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v
  23. # for k, v in d.items_ex(resolve=False) if fn(k)})
  24. for k, v in d.items() if fn(k)})
  25. @hydra.main(config_path="configs/", config_name="config.yaml")
  26. def main(config: DictConfig):
  27. # Remove config keys that start with '__'. These are meant to be used only in computing
  28. # other entries in the config.
  29. config = dictconfig_filter_key(config, lambda k: not k.startswith('__'))
  30. # Imports should be nested inside @hydra.main to optimize tab completion
  31. # Read more here: https://github.com/facebookresearch/hydra/issues/934
  32. from src.train import train
  33. from src.eval import evaluate
  34. from src.utils import utils
  35. # A couple of optional utilities:
  36. # - disabling python warnings
  37. # - forcing debug-friendly configuration
  38. # - verifying experiment name is set when running in experiment mode
  39. # You can safely get rid of this line if you don't want those
  40. utils.extras(config)
  41. # Pretty print config using Rich library
  42. if config.get("print_config"):
  43. utils.print_config(config, resolve=True)
  44. # Train model
  45. mode = config.get('mode', 'train')
  46. if mode not in ['train', 'eval']:
  47. raise NotImplementedError(f'mode {mode} not supported')
  48. if mode == 'train':
  49. return train(config)
  50. elif mode == 'eval':
  51. return evaluate(config)
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()