# base.py
import bisect

import albumentations
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
  6. class ConcatDatasetWithIndex(ConcatDataset):
  7. """Modified from original pytorch code to return dataset idx"""
  8. def __getitem__(self, idx):
  9. if idx < 0:
  10. if -idx > len(self):
  11. raise ValueError("absolute value of index should not exceed dataset length")
  12. idx = len(self) + idx
  13. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  14. if dataset_idx == 0:
  15. sample_idx = idx
  16. else:
  17. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  18. return self.datasets[dataset_idx][sample_idx], dataset_idx
  19. class ImagePaths(Dataset):
  20. def __init__(self, paths, size=None, random_crop=False, labels=None):
  21. self.size = size
  22. self.random_crop = random_crop
  23. self.labels = dict() if labels is None else labels
  24. self.labels["file_path_"] = paths
  25. self._length = len(paths)
  26. if self.size is not None and self.size > 0:
  27. self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
  28. if not self.random_crop:
  29. self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
  30. else:
  31. self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
  32. self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
  33. else:
  34. self.preprocessor = lambda **kwargs: kwargs
  35. def __len__(self):
  36. return self._length
  37. def preprocess_image(self, image_path):
  38. image = Image.open(image_path)
  39. if not image.mode == "RGB":
  40. image = image.convert("RGB")
  41. image = np.array(image).astype(np.uint8)
  42. image = self.preprocessor(image=image)["image"]
  43. image = (image/127.5 - 1.0).astype(np.float32)
  44. return image
  45. def __getitem__(self, i):
  46. example = dict()
  47. example["image"] = self.preprocess_image(self.labels["file_path_"][i])
  48. for k in self.labels:
  49. example[k] = self.labels[k][i]
  50. return example
  51. class NumpyPaths(ImagePaths):
  52. def preprocess_image(self, image_path):
  53. image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
  54. image = np.transpose(image, (1,2,0))
  55. image = Image.fromarray(image, mode="RGB")
  56. image = np.array(image).astype(np.uint8)
  57. image = self.preprocessor(image=image)["image"]
  58. image = (image/127.5 - 1.0).astype(np.float32)
  59. return image