123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import os, hashlib
- import requests
- from tqdm import tqdm
- URL_MAP = {
- "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
- }
- CKPT_MAP = {
- "vgg_lpips": "vgg.pth"
- }
- MD5_MAP = {
- "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
- }
- def download(url, local_path, chunk_size=1024):
- os.makedirs(os.path.split(local_path)[0], exist_ok=True)
- with requests.get(url, stream=True) as r:
- total_size = int(r.headers.get("content-length", 0))
- with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
- with open(local_path, "wb") as f:
- for data in r.iter_content(chunk_size=chunk_size):
- if data:
- f.write(data)
- pbar.update(chunk_size)
- def md5_hash(path):
- with open(path, "rb") as f:
- content = f.read()
- return hashlib.md5(content).hexdigest()
- def get_ckpt_path(name, root, check=False):
- assert name in URL_MAP
- path = os.path.join(root, CKPT_MAP[name])
- if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
- print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
- download(URL_MAP[name], path)
- md5 = md5_hash(path)
- assert md5 == MD5_MAP[name], md5
- return path
- class KeyNotFoundError(Exception):
- def __init__(self, cause, keys=None, visited=None):
- self.cause = cause
- self.keys = keys
- self.visited = visited
- messages = list()
- if keys is not None:
- messages.append("Key not found: {}".format(keys))
- if visited is not None:
- messages.append("Visited: {}".format(visited))
- messages.append("Cause:\n{}".format(cause))
- message = "\n".join(messages)
- super().__init__(message)
- def retrieve(
- list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
- ):
- """Given a nested list or dict return the desired value at key expanding
- callable nodes if necessary and :attr:`expand` is ``True``. The expansion
- is done in-place.
- Parameters
- ----------
- list_or_dict : list or dict
- Possibly nested list or dictionary.
- key : str
- key/to/value, path like string describing all keys necessary to
- consider to get to the desired value. List indices can also be
- passed here.
- splitval : str
- String that defines the delimiter between keys of the
- different depth levels in `key`.
- default : obj
- Value returned if :attr:`key` is not found.
- expand : bool
- Whether to expand callable nodes on the path or not.
- Returns
- -------
- The desired value or if :attr:`default` is not ``None`` and the
- :attr:`key` is not found returns ``default``.
- Raises
- ------
- Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
- ``None``.
- """
- keys = key.split(splitval)
- success = True
- try:
- visited = []
- parent = None
- last_key = None
- for key in keys:
- if callable(list_or_dict):
- if not expand:
- raise KeyNotFoundError(
- ValueError(
- "Trying to get past callable node with expand=False."
- ),
- keys=keys,
- visited=visited,
- )
- list_or_dict = list_or_dict()
- parent[last_key] = list_or_dict
- last_key = key
- parent = list_or_dict
- try:
- if isinstance(list_or_dict, dict):
- list_or_dict = list_or_dict[key]
- else:
- list_or_dict = list_or_dict[int(key)]
- except (KeyError, IndexError, ValueError) as e:
- raise KeyNotFoundError(e, keys=keys, visited=visited)
- visited += [key]
- # final expansion of retrieved value
- if expand and callable(list_or_dict):
- list_or_dict = list_or_dict()
- parent[last_key] = list_or_dict
- except KeyNotFoundError as e:
- if default is None:
- raise e
- else:
- list_or_dict = default
- success = False
- if not pass_success:
- return list_or_dict
- else:
- return list_or_dict, success
- if __name__ == "__main__":
- config = {"keya": "a",
- "keyb": "b",
- "keyc":
- {"cc1": 1,
- "cc2": 2,
- }
- }
- from omegaconf import OmegaConf
- config = OmegaConf.create(config)
- print(config)
- retrieve(config, "keya")
|