12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import bisect
- import numpy as np
- import albumentations
- from PIL import Image
- from torch.utils.data import Dataset, ConcatDataset
class ConcatDatasetWithIndex(ConcatDataset):
    """ConcatDataset variant whose items also carry the sub-dataset index.

    Adapted from the upstream PyTorch ``ConcatDataset.__getitem__``; the only
    difference is that the position of the source dataset is returned
    alongside the sample.
    """

    def __getitem__(self, idx):
        """Return ``(sample, dataset_idx)`` for the global index ``idx``.

        Negative indices count from the end of the concatenation.

        Raises:
            ValueError: if ``idx`` is negative and its magnitude exceeds the
                total length.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            # Rebind rather than update in place so the caller's index object
            # is never modified.
            idx = total + idx

        # cumulative_sizes holds the running end offsets of the sub-datasets,
        # so bisect_right yields the dataset that contains ``idx``.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[which - 1] if which > 0 else 0
        return self.datasets[which][idx - start], which
class ImagePaths(Dataset):
    """Dataset that loads RGB images from a sequence of file paths.

    Each item is a dict holding the preprocessed image under ``"image"`` plus
    the i-th entry of every label sequence (including ``"file_path_"``).

    Args:
        paths: indexable sequence of image file paths.
        size: target side length. When set and positive, the image is rescaled
            so its smallest side equals ``size`` and then cropped to
            ``size x size``. When ``None`` or non-positive, no resizing or
            cropping is applied.
        random_crop: use a random crop instead of a center crop.
        labels: optional mapping from label name to an indexable sequence
            aligned with ``paths``.
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop

        # Shallow-copy so that adding "file_path_" below never mutates the
        # dict object owned by the caller.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # A named static method (rather than a lambda) keeps the dataset
            # picklable, which multi-process data loading may require.
            self.preprocessor = self._passthrough

    @staticmethod
    def _passthrough(**kwargs):
        """Identity preprocessor: return the keyword arguments unchanged."""
        return kwargs

    def __len__(self):
        """Return the number of paths supplied at construction."""
        return self._length

    def preprocess_image(self, image_path):
        """Load ``image_path`` and return it as a float32 array in [-1, 1]."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the numpy array.
        with Image.open(image_path) as handle:
            rgb = handle if handle.mode == "RGB" else handle.convert("RGB")
            pixels = np.array(rgb).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        # Map uint8 values 0..255 onto -1..1.
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return the example dict for index ``i``."""
        example = {"image": self.preprocess_image(self.labels["file_path_"][i])}
        # Labels are copied after the image, so a label named "image" would
        # take precedence over the loaded pixels.
        for key, values in self.labels.items():
            example[key] = values[i]
        return example
class NumpyPaths(ImagePaths):
    """ImagePaths variant that reads each sample from a file via ``np.load``."""

    def preprocess_image(self, image_path):
        """Load an array from ``image_path`` and return it as float32 in [-1, 1].

        The stored array must have a leading axis of length 1, which is
        dropped; the remaining first axis is then moved to the end.
        NOTE(review): presumably the files hold (1, 3, H, W) uint8 data --
        confirm against whatever writes them.
        """
        channels_first = np.load(image_path).squeeze(0)
        channels_last = np.transpose(channels_first, (1, 2, 0))
        # Round-trip through PIL exactly as before to obtain an RGB uint8 array.
        pil_image = Image.fromarray(channels_last, mode="RGB")
        pixels = np.array(pil_image).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        # Map uint8 values 0..255 onto -1..1.
        return (pixels / 127.5 - 1.0).astype(np.float32)
|