2
0

util.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import os, hashlib
  2. import requests
  3. from tqdm import tqdm
  4. URL_MAP = {
  5. "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
  6. }
  7. CKPT_MAP = {
  8. "vgg_lpips": "vgg.pth"
  9. }
  10. MD5_MAP = {
  11. "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
  12. }
  13. def download(url, local_path, chunk_size=1024):
  14. os.makedirs(os.path.split(local_path)[0], exist_ok=True)
  15. with requests.get(url, stream=True) as r:
  16. total_size = int(r.headers.get("content-length", 0))
  17. with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
  18. with open(local_path, "wb") as f:
  19. for data in r.iter_content(chunk_size=chunk_size):
  20. if data:
  21. f.write(data)
  22. pbar.update(chunk_size)
  23. def md5_hash(path):
  24. with open(path, "rb") as f:
  25. content = f.read()
  26. return hashlib.md5(content).hexdigest()
  27. def get_ckpt_path(name, root, check=False):
  28. assert name in URL_MAP
  29. path = os.path.join(root, CKPT_MAP[name])
  30. if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
  31. print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
  32. download(URL_MAP[name], path)
  33. md5 = md5_hash(path)
  34. assert md5 == MD5_MAP[name], md5
  35. return path
  36. class KeyNotFoundError(Exception):
  37. def __init__(self, cause, keys=None, visited=None):
  38. self.cause = cause
  39. self.keys = keys
  40. self.visited = visited
  41. messages = list()
  42. if keys is not None:
  43. messages.append("Key not found: {}".format(keys))
  44. if visited is not None:
  45. messages.append("Visited: {}".format(visited))
  46. messages.append("Cause:\n{}".format(cause))
  47. message = "\n".join(messages)
  48. super().__init__(message)
  49. def retrieve(
  50. list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
  51. ):
  52. """Given a nested list or dict return the desired value at key expanding
  53. callable nodes if necessary and :attr:`expand` is ``True``. The expansion
  54. is done in-place.
  55. Parameters
  56. ----------
  57. list_or_dict : list or dict
  58. Possibly nested list or dictionary.
  59. key : str
  60. key/to/value, path like string describing all keys necessary to
  61. consider to get to the desired value. List indices can also be
  62. passed here.
  63. splitval : str
  64. String that defines the delimiter between keys of the
  65. different depth levels in `key`.
  66. default : obj
  67. Value returned if :attr:`key` is not found.
  68. expand : bool
  69. Whether to expand callable nodes on the path or not.
  70. Returns
  71. -------
  72. The desired value or if :attr:`default` is not ``None`` and the
  73. :attr:`key` is not found returns ``default``.
  74. Raises
  75. ------
  76. Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
  77. ``None``.
  78. """
  79. keys = key.split(splitval)
  80. success = True
  81. try:
  82. visited = []
  83. parent = None
  84. last_key = None
  85. for key in keys:
  86. if callable(list_or_dict):
  87. if not expand:
  88. raise KeyNotFoundError(
  89. ValueError(
  90. "Trying to get past callable node with expand=False."
  91. ),
  92. keys=keys,
  93. visited=visited,
  94. )
  95. list_or_dict = list_or_dict()
  96. parent[last_key] = list_or_dict
  97. last_key = key
  98. parent = list_or_dict
  99. try:
  100. if isinstance(list_or_dict, dict):
  101. list_or_dict = list_or_dict[key]
  102. else:
  103. list_or_dict = list_or_dict[int(key)]
  104. except (KeyError, IndexError, ValueError) as e:
  105. raise KeyNotFoundError(e, keys=keys, visited=visited)
  106. visited += [key]
  107. # final expansion of retrieved value
  108. if expand and callable(list_or_dict):
  109. list_or_dict = list_or_dict()
  110. parent[last_key] = list_or_dict
  111. except KeyNotFoundError as e:
  112. if default is None:
  113. raise e
  114. else:
  115. list_or_dict = default
  116. success = False
  117. if not pass_success:
  118. return list_or_dict
  119. else:
  120. return list_or_dict, success
  121. if __name__ == "__main__":
  122. config = {"keya": "a",
  123. "keyb": "b",
  124. "keyc":
  125. {"cc1": 1,
  126. "cc2": 2,
  127. }
  128. }
  129. from omegaconf import OmegaConf
  130. config = OmegaConf.create(config)
  131. print(config)
  132. retrieve(config, "keya")