123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558 |
- import os, tarfile, glob, shutil
- import yaml
- import numpy as np
- from tqdm import tqdm
- from PIL import Image
- import albumentations
- from omegaconf import OmegaConf
- from torch.utils.data import Dataset
- from taming.data.base import ImagePaths
- from taming.util import download, retrieve
- import taming.data.utils as bdu
def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
    """Translate class indices into synset identifiers via an index->synset YAML map.

    Args:
        indices: iterable of keys into the YAML mapping (class indices).
        path_to_yaml: path to the YAML file mapping index -> synset.

    Returns:
        List of synset identifiers (as ``str``), in the order of ``indices``.
        These are compared against the top-level directory names of the
        dataset's relative image paths.

    Raises:
        KeyError: if an index is not present in the mapping.
    """
    with open(path_to_yaml) as f:
        # safe_load: the file is plain data. ``yaml.load`` without an explicit
        # Loader is unsafe and raises TypeError on PyYAML >= 6.
        di2s = yaml.safe_load(f)
    synsets = [str(di2s[idx]) for idx in indices]
    print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
    return synsets
def str_to_indices(string):
    """Parse a comma-separated list of indices and ranges, e.g. '32-123, 256, 280-321'.

    A plain number yields itself. A range ``a-b`` expands to ``a, a+1, ..., b-1``;
    the upper bound is exclusive, as with ``range``. Whitespace around numbers is
    tolerated. Any parts after a second '-' in a range are ignored.

    Returns:
        All parsed indices, sorted ascending (duplicates are kept).

    Raises:
        AssertionError: if the string ends with a comma.
        ValueError: if a component is not an integer.
    """
    assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
    result = []
    for part in string.split(","):
        bounds = part.split("-")
        assert len(bounds) > 0
        if len(bounds) == 1:
            result.append(int(bounds[0]))
            continue
        start = int(bounds[0])
        stop = int(bounds[1])
        result.extend(range(start, stop))
    return sorted(result)
class ImageNetBase(Dataset):
    """Shared loading logic for the ImageNet splits.

    Subclasses implement ``_prepare``. It must set ``self.root`` (where the
    label files are cached), ``self.datadir`` (image root),
    ``self.txt_filelist`` (a file with one relative image path per line) and
    ``self.random_crop``.

    Construction then:
      * downloads the label files if needed;
      * reads and filters the file list;
      * wraps everything in an ``ImagePaths`` dataset stored as ``self.data``.

    Indexing and ``len()`` delegate to ``self.data``.

    Config keys read here:
      * ``sub_indices`` — optional; restricts classes, see ``str_to_indices``.
      * ``size`` — forwarded to ``ImagePaths``, default 0.
    """

    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        # Plain dicts, including subclasses such as OrderedDict, are used as-is.
        # Anything else is treated as an OmegaConf config and converted to
        # plain containers.
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        """Set up split-specific paths and files; implemented by subclasses."""
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        """Drop excluded files and optionally keep only the configured classes.

        Args:
            relpaths: relative image paths of the form ``<synset>/<file>``.

        Returns:
            A new list of the kept paths, in their original order.
        """
        # Individual files that are always excluded.
        ignore = {
            "n06596364_9591.JPEG",
        }
        relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
        if "sub_indices" not in self.config:
            return relpaths
        indices = str_to_indices(self.config["sub_indices"])
        # A set keeps the per-path membership test O(1) instead of scanning the
        # synset list once for every path.
        synsets = set(give_synsets_from_indices(indices, path_to_yaml=self.idx2syn))
        return [rpath for rpath in relpaths if rpath.split("/")[0] in synsets]

    def _prepare_synset_to_human(self):
        """Ensure ``synset_human.txt`` (synset -> human-readable label) exists in ``self.root``.

        The file is (re)downloaded when it is missing or its size differs
        from the expected one.
        """
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        """Ensure ``index_synset.yaml`` (class index -> synset) exists in ``self.root``."""
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _load(self):
        """Read, filter and label the file list, then build ``self.data``."""
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
        l1 = len(self.relpaths)
        self.relpaths = self._filter_relpaths(self.relpaths)
        print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        # The synset is the first path component.
        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        # Class labels are indices into the sorted set of synsets present
        # after filtering.
        unique_synsets = np.unique(self.synsets)
        class_dict = {synset: i for i, synset in enumerate(unique_synsets)}
        self.class_labels = [class_dict[s] for s in self.synsets]

        # Each line is "<synset> <human readable label>".
        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
        self.data = ImagePaths(self.abspaths,
                               labels=labels,
                               size=retrieve(self.config, "size", default=0),
                               random_crop=self.random_crop)
class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split (1281167 images).

    Data is cached under ``$XDG_CACHE_HOME`` (default ``~/.cache``) in
    ``autoencoders/data/ILSVRC2012_train``.

    On first use (until ``bdu.mark_prepared`` has run on the root):
      * the training tarball is fetched via ``academictorrents`` if it is
        missing or has an unexpected size;
      * it is extracted, together with the nested tarballs it contains;
      * a sorted list of all ``*.JPEG`` files (relative to ``datadir``) is
        written to ``filelist.txt``.

    Config key ``ImageNetTrain/random_crop`` (default True) is forwarded as
    ``random_crop`` to the wrapped ``ImagePaths`` dataset.
    """
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    @staticmethod
    def _extract_tar(path, dest):
        """Extract the uncompressed tar archive at ``path`` into ``dest``.

        When the running Python provides extraction filters, the ``"data"``
        filter is used. It refuses members that would land outside ``dest``
        (absolute paths, ``..`` components, links pointing outside). It also
        avoids the DeprecationWarning that unfiltered ``extractall`` emits on
        Python 3.12+.
        """
        with tarfile.open(path, "r:") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=dest, filter="data")
            else:
                tar.extractall(path=dest)

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        # Size of the full split; not checked anywhere in this class.
        self.expected_length = 1281167
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    # Imported lazily: only needed when the archive must be fetched.
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                self._extract_tar(path, datadir)

                # The outer archive holds nested tarballs. Each is unpacked
                # into a directory named after it (minus ".tar"); that name
                # later serves as the synset label.
                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    self._extract_tar(subpath, subdir)

            # Non-recursive glob: "**" behaves like "*" here, so this matches
            # images exactly one directory level below datadir.
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = sorted(os.path.relpath(p, start=datadir) for p in filelist)
            with open(self.txt_filelist, "w") as f:
                f.write("\n".join(filelist)+"\n")

            bdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split (50000 images).

    Data is cached under ``$XDG_CACHE_HOME`` (default ``~/.cache``) in
    ``autoencoders/data/ILSVRC2012_validation``.

    On first use (until ``bdu.mark_prepared`` has run on the root):
      * the image tarball is fetched via ``academictorrents`` if it is missing
        or has an unexpected size, then extracted;
      * the extracted files are moved into per-synset folders according to
        ``validation_synset.txt``, which is downloaded if needed;
      * a sorted list of all ``*.JPEG`` files (relative to ``datadir``) is
        written to ``filelist.txt``.

    Config key ``ImageNetValidation/random_crop`` (default False) is
    forwarded as ``random_crop`` to the wrapped ``ImagePaths`` dataset.
    """
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    @staticmethod
    def _extract_tar(path, dest):
        """Extract the uncompressed tar archive at ``path`` into ``dest``.

        When the running Python provides extraction filters, the ``"data"``
        filter is used. It refuses members that would land outside ``dest``
        (absolute paths, ``..`` components, links pointing outside). It also
        avoids the DeprecationWarning that unfiltered ``extractall`` emits on
        Python 3.12+.
        """
        with tarfile.open(path, "r:") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=dest, filter="data")
            else:
                tar.extractall(path=dest)

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        # Size of the full split; not checked anywhere in this class.
        self.expected_length = 50000
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    # Imported lazily: only needed when the archive must be fetched.
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                self._extract_tar(path, datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                # Each line: "<file name in datadir> <synset folder>".
                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            # Non-recursive glob: "**" behaves like "*" here, so this matches
            # images exactly one directory level below datadir.
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = sorted(os.path.relpath(p, start=datadir) for p in filelist)
            with open(self.txt_filelist, "w") as f:
                f.write("\n".join(filelist)+"\n")

            bdu.mark_prepared(self.root)
def get_preprocessor(size=None, random_crop=False, additional_targets=None,
                     crop_size=None):
    """Build a joint albumentations pipeline, or an identity callable.

    * ``size > 0``: rescale so the shorter side equals ``size``, then crop
      ``size x size``. The crop is centered, or random when ``random_crop``
      is set; the random case also appends ``HorizontalFlip``. ``crop_size``
      is ignored in this case.
    * otherwise, ``crop_size > 0``: only crop ``crop_size x crop_size``
      (center or random). No rescale, no flip.
    * otherwise: return a callable that hands back its keyword arguments
      unchanged.

    ``additional_targets`` is forwarded to ``albumentations.Compose`` so that
    extra arrays (e.g. ``{"depth": "image"}``) get the same transform.
    """
    if size is not None and size > 0:
        side, rescale = size, True
    elif crop_size is not None and crop_size > 0:
        side, rescale = crop_size, False
    else:
        # Nothing to do: pass every keyword target through untouched.
        return lambda **kwargs: kwargs

    steps = []
    if rescale:
        steps.append(albumentations.SmallestMaxSize(max_size=side))
    crop_cls = albumentations.RandomCrop if random_crop else albumentations.CenterCrop
    steps.append(crop_cls(height=side, width=side))
    if rescale and random_crop:
        steps.append(albumentations.HorizontalFlip())
    return albumentations.Compose(steps, additional_targets=additional_targets)
def rgba_to_depth(x):
    """Decode a depth map stored bytewise in an RGBA uint8 image.

    The 4 bytes of each pixel are reinterpreted, in native byte order, as one
    float32 value.

    Args:
        x: uint8 array of shape (H, W, 4).

    Returns:
        A new C-contiguous float32 array of shape (H, W). It never shares
        memory with ``x``.

    Raises:
        AssertionError: if ``x`` is not uint8 or not of shape (H, W, 4).
    """
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    # Copy first (C order) so the result owns its buffer, then reinterpret the
    # bytes with ``view``. Assigning to ``ndarray.dtype`` in place is
    # discouraged by NumPy.
    y = x.copy().view(np.float32)
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
class BaseWithDepth(Dataset):
    """Wrap an image dataset and attach a normalized depth map to every item.

    Subclasses provide:
      * ``get_base_dset()`` — returns a dataset whose items are dicts holding
        an ``"image"`` array of rank 3 (H, W, C);
      * ``get_depth_path(example)`` — returns the path of the RGBA-encoded
        depth PNG for that item.

    Image and depth pass through the same ``get_preprocessor`` pipeline, with
    depth registered as an extra image target.

    Other arguments:
      * ``root`` overrides ``DEFAULT_DEPTH_ROOT`` for this instance only.
      * ``config`` is stored on the instance but not used by this class.
    """
    DEFAULT_DEPTH_ROOT = "data/imagenet_depth"

    def __init__(self, config=None, size=None, random_crop=False,
                 crop_size=None, root=None):
        self.config = config
        self.base_dset = self.get_base_dset()
        self.preprocessor = get_preprocessor(size=size,
                                             crop_size=crop_size,
                                             random_crop=random_crop,
                                             additional_targets={"depth": "image"})
        self.crop_size = crop_size
        if crop_size is not None:
            # Used in __getitem__ to upscale items whose shorter side is below
            # crop_size, so that a crop of that size can be taken.
            upscale = albumentations.SmallestMaxSize(max_size=crop_size)
            self.rescaler = albumentations.Compose([upscale],
                                                   additional_targets={"depth": "image"})
        if root is not None:
            self.DEFAULT_DEPTH_ROOT = root

    def __len__(self):
        return len(self.base_dset)

    def preprocess_depth(self, path):
        """Load the depth PNG at ``path`` and min-max rescale it to [-1, 1]."""
        depth = rgba_to_depth(np.array(Image.open(path)))
        lo = depth.min()
        hi = depth.max()
        # The 1e-8 floor avoids division by zero for constant depth maps.
        unit = (depth - lo) / max(1e-8, hi - lo)
        return 2.0 * unit - 1.0

    def __getitem__(self, i):
        item = self.base_dset[i]
        item["depth"] = self.preprocess_depth(self.get_depth_path(item))

        rows, cols, _ = item["image"].shape
        if self.crop_size and min(rows, cols) < self.crop_size:
            # Too small to crop: upscale image and depth together (bilinear).
            upscaled = self.rescaler(image=item["image"], depth=item["depth"])
            item["image"] = upscaled["image"]
            item["depth"] = upscaled["depth"]

        joint = self.preprocessor(image=item["image"], depth=item["depth"])
        item["image"] = joint["image"]
        item["depth"] = joint["depth"]
        return item
class ImageNetTrainWithDepth(BaseWithDepth):
    """ImageNet training split paired with depth maps under ``<depth root>/train``.

    Random cropping is on by default. ``sub_indices``, if given, is passed to
    ``ImageNetTrain`` as its ``sub_indices`` config entry.
    """

    def __init__(self, random_crop=True, sub_indices=None, **kwargs):
        # Must be set before super().__init__, which calls get_base_dset().
        self.sub_indices = sub_indices
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base_dset(self):
        if self.sub_indices is not None:
            return ImageNetTrain({"sub_indices": self.sub_indices})
        return ImageNetTrain()

    def get_depth_path(self, e):
        """Map an item's ``relpath`` to its depth PNG (same path, ``.png`` extension)."""
        stem, _ext = os.path.splitext(e["relpath"])
        return os.path.join(self.DEFAULT_DEPTH_ROOT, "train", stem + ".png")
class ImageNetValidationWithDepth(BaseWithDepth):
    """ImageNet validation split paired with depth maps under ``<depth root>/val``.

    ``sub_indices``, if given, is passed to ``ImageNetValidation`` as its
    ``sub_indices`` config entry.
    """

    def __init__(self, sub_indices=None, **kwargs):
        # Must be set before super().__init__, which calls get_base_dset().
        self.sub_indices = sub_indices
        super().__init__(**kwargs)

    def get_base_dset(self):
        if self.sub_indices is not None:
            return ImageNetValidation({"sub_indices": self.sub_indices})
        return ImageNetValidation()

    def get_depth_path(self, e):
        """Map an item's ``relpath`` to its depth PNG (same path, ``.png`` extension)."""
        stem, _ext = os.path.splitext(e["relpath"])
        return os.path.join(self.DEFAULT_DEPTH_ROOT, "val", stem + ".png")
class RINTrainWithDepth(ImageNetTrainWithDepth):
    """Restricted-ImageNet training split with depth: a fixed subset of classes.

    The subset is given by the class-index ranges below. They are parsed by
    ``str_to_indices``, which treats ``a-b`` as half-open (``a`` .. ``b-1``).
    NOTE(review): confirm the half-open ranges match the intended class set.
    """

    def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
        restricted_classes = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config,
                         size=size,
                         random_crop=random_crop,
                         sub_indices=restricted_classes,
                         crop_size=crop_size)
class RINValidationWithDepth(ImageNetValidationWithDepth):
    """Restricted-ImageNet validation split with depth: a fixed subset of classes.

    The subset is given by the class-index ranges below. They are parsed by
    ``str_to_indices``, which treats ``a-b`` as half-open (``a`` .. ``b-1``).
    NOTE(review): confirm the half-open ranges match the intended class set.
    """

    def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
        restricted_classes = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config,
                         size=size,
                         random_crop=random_crop,
                         sub_indices=restricted_classes,
                         crop_size=crop_size)
class DRINExamples(Dataset):
    """Small fixed set of example images with depth maps.

    Inputs on disk:
      * relative paths are read from ``data/drin_examples.txt``;
      * images live under ``data/drin_images``;
      * depth PNGs live under ``data/drin_depth``, at the same relative path
        with ``.JPEG`` replaced by ``.png``.

    Items are dicts with:
      * ``"image"`` — float32 in [-1, 1];
      * ``"depth"`` — min-max scaled to [-1, 1].

    Both are resized (shorter side 256) and center-cropped to 256x256 in one
    joint call.
    """

    def __init__(self):
        self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
        with open("data/drin_examples.txt", "r") as f:
            relpaths = f.read().splitlines()
        self.image_paths = [os.path.join("data/drin_images",
                                         relpath) for relpath in relpaths]
        self.depth_paths = [os.path.join("data/drin_depth",
                                         relpath.replace(".JPEG", ".png")) for relpath in relpaths]

    def __len__(self):
        return len(self.image_paths)

    def _load_rgb(self, image_path):
        """Read ``image_path`` as an RGB uint8 array, without resizing."""
        with Image.open(image_path) as image:
            if not image.mode == "RGB":
                image = image.convert("RGB")
            return np.array(image).astype(np.uint8)

    def preprocess_image(self, image_path):
        """Load an image on its own, resize/crop it and scale it to float32 [-1, 1]."""
        image = self._load_rgb(image_path)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def preprocess_depth(self, path):
        """Load the RGBA-encoded depth PNG at ``path`` and rescale it to [-1, 1]."""
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        # The 1e-8 floor avoids division by zero for constant depth maps.
        depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
        depth = 2.0*depth-1.0
        return depth

    def __getitem__(self, i):
        image = self._load_rgb(self.image_paths[i])
        depth = self.preprocess_depth(self.depth_paths[i])
        # Transform image and depth in a single joint call, starting from the
        # same source geometry, so both receive identical resize/crop
        # parameters. Resizing the image separately first would hand the
        # joint step an image and a depth map of different sizes.
        transformed = self.preprocessor(image=image, depth=depth)
        e = dict()
        e["image"] = (transformed["image"]/127.5 - 1.0).astype(np.float32)
        e["depth"] = transformed["depth"]
        return e
def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
    """Downscale an (H, W, C) image with values in [-1, 1] by ``factor``.

    Behavior:
      * ``factor`` of None or 1 returns ``x`` itself.
      * Otherwise the image is quantized to uint8 and shrunk to
        ``(H // factor, W // factor)`` with bicubic resampling.
      * If ``keepshapes`` is true, the result is resized back to (H, W) using
        ``keepmode`` ("nearest", "bilinear" or "bicubic").

    The output is mapped back to [-1, 1] in the input's float dtype.

    Raises:
        AssertionError: on a non-float32/float64 input, values outside [-1, 1],
            or a factor that would shrink a side to zero.
        KeyError: on an unknown ``keepmode``.
    """
    if factor is None or factor == 1:
        return x
    in_dtype = x.dtype
    assert in_dtype in [np.float32, np.float64]
    assert x.min() >= -1
    assert x.max() <= 1
    resample_up = {
        "nearest": Image.NEAREST,
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
    }[keepmode]

    pixels = ((x + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    pil_img = Image.fromarray(pixels)
    height, width, _ = x.shape
    small_h, small_w = height // factor, width // factor
    assert small_h > 0 and small_w > 0, (small_h, small_w)

    # The downscale is always bicubic; keepmode only affects the upscale back.
    pil_img = pil_img.resize((small_w, small_h), Image.BICUBIC)
    if keepshapes:
        pil_img = pil_img.resize((width, height), resample_up)
    return (np.array(pil_img) / 127.5 - 1.0).astype(in_dtype)
class ImageNetScale(Dataset):
    """Image dataset with optional resolution change and low-res companion.

    Subclasses provide ``get_base()``, returning a dataset whose items are
    dicts with an ``"image"`` array.

    Args:
        size: if > 0, rescale so the shorter side equals ``size``.
        crop_size: square crop size; defaults to ``size``. Crops are random
            when ``random_crop`` is set, centered otherwise.
        random_crop: choose ``RandomCrop`` instead of ``CenterCrop``.
        up_factor: if set, an ``"lr"`` entry is added to each item. It is the
            image downscaled by this factor and resized back with
            ``keep_mode``, cropped jointly with the image.
        hr_factor: downscale the image by this factor before anything else.
        keep_mode: interpolation for the upsample in ``imscale``.

    ``imscale`` asserts that images are float32/float64 in [-1, 1] whenever
    a factor other than None/1 is applied.
    """

    def __init__(self, size=None, crop_size=None, random_crop=False,
                 up_factor=None, hr_factor=None, keep_mode="bicubic"):
        self.base = self.get_base()
        self.size = size
        self.crop_size = self.size if crop_size is None else crop_size
        self.random_crop = random_crop
        self.up_factor = up_factor
        self.hr_factor = hr_factor
        self.keep_mode = keep_mode

        steps = []
        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            steps.append(self.rescaler)
        if self.crop_size is not None and self.crop_size > 0:
            if not steps:
                # No global rescale step: keep a rescaler so that too-small
                # images can still be upscaled to crop_size in __getitem__.
                self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size)
            crop_cls = (albumentations.RandomCrop if self.random_crop
                        else albumentations.CenterCrop)
            steps.append(crop_cls(height=self.crop_size, width=self.crop_size))

        if not steps:
            self.preprocessor = lambda **kwargs: kwargs
        else:
            # "lr" is only passed to the preprocessor when up_factor is set.
            targets = None if self.up_factor is None else {"lr": "image"}
            self.preprocessor = albumentations.Compose(steps,
                                                       additional_targets=targets)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        hr = imscale(example["image"], self.hr_factor, keepshapes=False)

        rows, cols, _ = hr.shape
        if self.crop_size and min(rows, cols) < self.crop_size:
            # Too small to crop: upscale first (albumentations rescaler).
            hr = self.rescaler(image=hr)["image"]

        if self.up_factor is not None:
            lr = imscale(hr, self.up_factor, keepshapes=True,
                         keepmode=self.keep_mode)
            joint = self.preprocessor(image=hr, lr=lr)
            example["image"] = joint["image"]
            example["lr"] = joint["lr"]
        else:
            example["image"] = self.preprocessor(image=hr)["image"]
        return example
class ImageNetScaleTrain(ImageNetScale):
    """``ImageNetScale`` over the ImageNet training split; random crops by default."""

    def __init__(self, random_crop=True, **kwargs):
        super(ImageNetScaleTrain, self).__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        """Return a fresh ``ImageNetTrain`` dataset with default config."""
        return ImageNetTrain()
class ImageNetScaleValidation(ImageNetScale):
    """``ImageNetScale`` over the ImageNet validation split (center crops by default)."""

    def get_base(self):
        """Return a fresh ``ImageNetValidation`` dataset with default config."""
        return ImageNetValidation()
- from skimage.feature import canny
- from skimage.color import rgb2gray
class ImageNetEdges(ImageNetScale):
    """Variant of ``ImageNetScale`` whose ``"lr"`` entry is a Canny edge map.

    The ``up_factor`` argument is accepted but ignored. The parent is always
    built with ``up_factor=1`` so that its preprocessor registers the extra
    ``"lr"`` target. ``hr_factor`` is not applied in ``__getitem__``.
    """

    def __init__(self, up_factor=1, **kwargs):
        super().__init__(up_factor=1, **kwargs)

    def __getitem__(self, i):
        example = self.base[i]
        rgb = example["image"]

        rows, cols, _ = rgb.shape
        if self.crop_size and min(rows, cols) < self.crop_size:
            # Too small to crop: upscale first (albumentations rescaler).
            rgb = self.rescaler(image=rgb)["image"]

        edges = canny(rgb2gray(rgb), sigma=2).astype(np.float32)
        # Replicate the single edge channel to three channels.
        edges = edges[:, :, None][:, :, [0, 0, 0]]

        joint = self.preprocessor(image=rgb, lr=edges)
        example["image"] = joint["image"]
        example["lr"] = joint["lr"]
        return example
class ImageNetEdgesTrain(ImageNetEdges):
    """Edge-map dataset over the ImageNet training split; random crops by default."""

    def __init__(self, random_crop=True, **kwargs):
        super(ImageNetEdgesTrain, self).__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        """Return a fresh ``ImageNetTrain`` dataset with default config."""
        return ImageNetTrain()
class ImageNetEdgesValidation(ImageNetEdges):
    """Edge-map dataset over the ImageNet validation split (center crops by default)."""

    def get_base(self):
        """Return a fresh ``ImageNetValidation`` dataset with default config."""
        return ImageNetValidation()