import collections
import os
import tarfile
import urllib.request
import zipfile
from pathlib import Path

import numpy as np
import torch
from taming.data.helper_types import Annotation
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from tqdm import tqdm


def unpack(path):
    """Extract a .tar.gz, .tar or .zip archive into the directory containing it."""
    if path.endswith("tar.gz"):
        with tarfile.open(path, "r:gz") as tar:
            tar.extractall(path=os.path.split(path)[0])
    elif path.endswith("tar"):
        with tarfile.open(path, "r:") as tar:
            tar.extractall(path=os.path.split(path)[0])
    elif path.endswith("zip"):
        with zipfile.ZipFile(path, "r") as f:
            f.extractall(path=os.path.split(path)[0])
    else:
        raise NotImplementedError(
            "Unknown file extension: {}".format(os.path.splitext(path)[1])
        )


def reporthook(bar):
    """tqdm progress bar for downloads."""

    def hook(b=1, bsize=1, tsize=None):
        if tsize is not None:
            bar.total = tsize
        bar.update(b * bsize - bar.n)

    return hook


def get_root(name):
    """Return (and create if necessary) the data directory for a dataset."""
    base = "data/"
    root = os.path.join(base, name)
    os.makedirs(root, exist_ok=True)
    return root


def is_prepared(root):
    """Check whether a dataset directory has already been fully prepared."""
    return Path(root).joinpath(".ready").exists()


def mark_prepared(root):
    """Mark a dataset directory as fully prepared by touching a .ready file."""
    Path(root).joinpath(".ready").touch()


def prompt_download(file_, source, target_dir, content_dir=None):
    """Ask the user to manually download file_ from source into target_dir
    and block until the file (or the unpacked content_dir) exists."""
    targetpath = os.path.join(target_dir, file_)
    while not os.path.exists(targetpath):
        if content_dir is not None and os.path.exists(
            os.path.join(target_dir, content_dir)
        ):
            break
        print(
            "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
        )
        if content_dir is not None:
            print(
                "Or place its content into '{}'.".format(
                    os.path.join(target_dir, content_dir)
                )
            )
        input("Press Enter when done...")
    return targetpath


def download_url(file_, url, target_dir):
    """Download url to target_dir/file_ with a tqdm progress bar."""
    targetpath = os.path.join(target_dir, file_)
    os.makedirs(target_dir, exist_ok=True)
    with tqdm(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
    ) as bar:
        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
    return targetpath


def download_urls(urls, target_dir):
    """Download a {filename: url} mapping and return a {filename: path} mapping."""
    paths = dict()
    for fname, url in urls.items():
        outpath = download_url(fname, url, target_dir)
        paths[fname] = outpath
    return paths
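

# A minimal usage sketch (not called anywhere in this module) of the
# prepare-once pattern these helpers support. The dataset name, file name and
# URL below are illustrative placeholders, not real resources from this repo.
def _example_prepare_dataset():
    root = get_root("example_dataset")
    if not is_prepared(root):
        paths = download_urls(
            {"images.zip": "https://example.com/images.zip"}, target_dir=root
        )
        unpack(paths["images.zip"])
        mark_prepared(root)
    return root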


def quadratic_crop(x, bbox, alpha=1.0):
    """Crop a square region of side alpha * max(bbox width, bbox height),
    centered on the bbox; bbox is (xmin, ymin, xmax, ymax)."""
    im_h, im_w = x.shape[:2]
    bbox = np.array(bbox, dtype=np.float32)
    bbox = np.clip(bbox, 0, max(im_h, im_w))
    center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    l = int(alpha * max(w, h))
    l = max(l, 2)
    # reflect-pad the image if the square would extend beyond its borders
    required_padding = -1 * min(
        center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
    )
    required_padding = int(np.ceil(required_padding))
    if required_padding > 0:
        padding = [
            [required_padding, required_padding],
            [required_padding, required_padding],
        ]
        padding += [[0, 0]] * (len(x.shape) - 2)
        x = np.pad(x, padding, "reflect")
        center = center[0] + required_padding, center[1] + required_padding
    xmin = int(center[0] - l / 2)
    ymin = int(center[1] - l / 2)
    return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
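

# A minimal usage sketch for quadratic_crop (shapes and bbox are illustrative,
# not taken from this repo): the returned crop is square with side
# alpha * max(bbox width, bbox height), and the image is reflect-padded when
# the square would reach past its borders.
def _example_quadratic_crop():
    img = np.zeros((100, 200, 3), dtype=np.uint8)  # H x W x C
    crop = quadratic_crop(img, bbox=(60, 10, 140, 50), alpha=1.0)
    assert crop.shape[:2] == (80, 80)  # side = max(140 - 60, 50 - 10) = 80
    return crop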


def custom_collate(batch):
    """Collate function copied from pytorch 1.9.0's default_collate, with a
    single modification: sequences of Annotation objects are returned as-is
    instead of being collated (see the branch marked "added" below)."""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
    if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation):  # added
        return batch  # added
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
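

# A minimal usage sketch: pass custom_collate to a standard DataLoader so that
# batches containing lists of Annotation objects pass through uncollated while
# everything else follows torch's default_collate behaviour. The dataset
# argument and batch size are placeholders, not part of this module.
def _example_dataloader(dataset):
    return torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, collate_fn=custom_collate
    )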