123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- import os
- import numpy as np
- import albumentations
- from torch.utils.data import Dataset
- from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
class FacesBase(Dataset):
    """Minimal dataset wrapper that can project each example onto a subset of keys.

    Subclasses call ``super().__init__()`` and then assign ``self.data`` (an
    indexable, sized collection whose items support ``item[key]``) and
    ``self.keys`` (``None`` or an iterable of key names).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Placeholders only; concrete subclasses overwrite both attributes.
        self.data = None
        self.keys = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i``.

        When ``self.keys`` is ``None`` the underlying example object itself is
        returned (no copy). Otherwise a new dict is built containing only the
        requested keys, in the order given by ``self.keys``. A missing key raises
        ``KeyError``.
        """
        record = self.data[i]
        if self.keys is None:
            return record
        return {name: record[name] for name in self.keys}
class CelebAHQTrain(FacesBase):
    """CelebA-HQ training split.

    Reads relative file paths (one per line) from ``data/celebahqtrain.txt``,
    joins each onto ``data/celebahq`` and wraps the result in ``NumpyPaths``
    with ``random_crop=False``. ``size`` is forwarded to ``NumpyPaths``
    unchanged; ``keys`` optionally restricts the fields returned per example.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/celebahq"
        with open("data/celebahqtrain.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = NumpyPaths(
            paths=[os.path.join(image_root, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class CelebAHQValidation(FacesBase):
    """CelebA-HQ validation split.

    Reads relative file paths (one per line) from
    ``data/celebahqvalidation.txt``, joins each onto ``data/celebahq`` and wraps
    the result in ``NumpyPaths`` with ``random_crop=False``. ``size`` is
    forwarded to ``NumpyPaths`` unchanged; ``keys`` optionally restricts the
    fields returned per example.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/celebahq"
        with open("data/celebahqvalidation.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = NumpyPaths(
            paths=[os.path.join(image_root, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class FFHQTrain(FacesBase):
    """FFHQ training split.

    Reads relative file paths (one per line) from ``data/ffhqtrain.txt``,
    joins each onto ``data/ffhq`` and wraps the result in ``ImagePaths`` with
    ``random_crop=False``. ``size`` is forwarded to ``ImagePaths`` unchanged;
    ``keys`` optionally restricts the fields returned per example.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/ffhq"
        with open("data/ffhqtrain.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = ImagePaths(
            paths=[os.path.join(image_root, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class FFHQValidation(FacesBase):
    """FFHQ validation split.

    Reads relative file paths (one per line) from ``data/ffhqvalidation.txt``,
    joins each onto ``data/ffhq`` and wraps the result in ``ImagePaths`` with
    ``random_crop=False``. ``size`` is forwarded to ``ImagePaths`` unchanged;
    ``keys`` optionally restricts the fields returned per example.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/ffhq"
        with open("data/ffhqvalidation.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = ImagePaths(
            paths=[os.path.join(image_root, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class FacesHQTrain(Dataset):
    """Training split concatenating CelebA-HQ and FFHQ.

    Items come from ``ConcatDatasetWithIndex`` as ``(example, index)`` pairs; the
    index is stored under ``example["class"]`` (per the original layout note:
    CelebAHQ -> 0, FFHQ -> 1).

    Args:
        size: forwarded to both underlying datasets.
        keys: optional key subset forwarded to both underlying datasets.
        crop_size: if given, ``example["image"]`` is randomly cropped to
            ``crop_size`` x ``crop_size`` via ``albumentations.RandomCrop``.
        coord: only takes effect together with ``crop_size``. When truthy, a
            per-pixel map of the flattened row-major pixel index divided by
            ``h * w`` (shape ``(h, w, 1)``) is cropped with the same window and
            returned under ``example["coord"]``.
    """

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        celeba = CelebAHQTrain(size=size, keys=keys)
        ffhq = FFHQTrain(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([celeba, ffhq])
        self.coord = coord
        if crop_size is None:
            # No cropper attribute at all: __getitem__ tests for its presence.
            return
        crop = albumentations.RandomCrop(height=crop_size, width=crop_size)
        if coord:
            # Register "coord" as an extra image target so it shares the crop window.
            crop = albumentations.Compose([crop], additional_targets={"coord": "image"})
        self.cropper = crop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i`` (cropped if configured) with its source index under ``"class"``."""
        example, source = self.data[i]
        if hasattr(self, "cropper"):
            targets = {"image": example["image"]}
            if self.coord:
                h, w, _ = targets["image"].shape
                n_pixels = h * w
                targets["coord"] = np.arange(n_pixels).reshape(h, w, 1) / n_pixels
            cropped = self.cropper(**targets)
            # Write back in the same order the targets were built (image, then coord).
            for name in targets:
                example[name] = cropped[name]
        example["class"] = source
        return example
class FacesHQValidation(Dataset):
    """Validation split concatenating CelebA-HQ and FFHQ.

    Items come from ``ConcatDatasetWithIndex`` as ``(example, index)`` pairs; the
    index is stored under ``example["class"]`` (per the original layout note:
    CelebAHQ -> 0, FFHQ -> 1).

    Args:
        size: forwarded to both underlying datasets.
        keys: optional key subset forwarded to both underlying datasets.
        crop_size: if given, ``example["image"]`` is center-cropped to
            ``crop_size`` x ``crop_size`` via ``albumentations.CenterCrop``.
        coord: only takes effect together with ``crop_size``. When truthy, a
            per-pixel map of the flattened row-major pixel index divided by
            ``h * w`` (shape ``(h, w, 1)``) is cropped with the same window and
            returned under ``example["coord"]``.
    """

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        celeba = CelebAHQValidation(size=size, keys=keys)
        ffhq = FFHQValidation(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([celeba, ffhq])
        self.coord = coord
        if crop_size is None:
            # No cropper attribute at all: __getitem__ tests for its presence.
            return
        crop = albumentations.CenterCrop(height=crop_size, width=crop_size)
        if coord:
            # Register "coord" as an extra image target so it shares the crop window.
            crop = albumentations.Compose([crop], additional_targets={"coord": "image"})
        self.cropper = crop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i`` (cropped if configured) with its source index under ``"class"``."""
        example, source = self.data[i]
        if hasattr(self, "cropper"):
            targets = {"image": example["image"]}
            if self.coord:
                h, w, _ = targets["image"].shape
                n_pixels = h * w
                targets["coord"] = np.arange(n_pixels).reshape(h, w, 1) / n_pixels
            cropped = self.cropper(**targets)
            # Write back in the same order the targets were built (image, then coord).
            for name in targets:
                example[name] = cropped[name]
        example["class"] = source
        return example
|