# ade20k.py — ADE20k dataset definitions (images + semantic maps + scene labels)
  1. import os
  2. import numpy as np
  3. import cv2
  4. import albumentations
  5. from PIL import Image
  6. from torch.utils.data import Dataset
  7. from taming.data.sflckr import SegmentationBase # for examples included in repo
  8. class Examples(SegmentationBase):
  9. def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
  10. super().__init__(data_csv="data/ade20k_examples.txt",
  11. data_root="data/ade20k_images",
  12. segmentation_root="data/ade20k_segmentations",
  13. size=size, random_crop=random_crop,
  14. interpolation=interpolation,
  15. n_labels=151, shift_segmentation=False)
  16. # With semantic map and scene label
  17. class ADE20kBase(Dataset):
  18. def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
  19. self.split = self.get_split()
  20. self.n_labels = 151 # unknown + 150
  21. self.data_csv = {"train": "data/ade20k_train.txt",
  22. "validation": "data/ade20k_test.txt"}[self.split]
  23. self.data_root = "data/ade20k_root"
  24. with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
  25. self.scene_categories = f.read().splitlines()
  26. self.scene_categories = dict(line.split() for line in self.scene_categories)
  27. with open(self.data_csv, "r") as f:
  28. self.image_paths = f.read().splitlines()
  29. self._length = len(self.image_paths)
  30. self.labels = {
  31. "relative_file_path_": [l for l in self.image_paths],
  32. "file_path_": [os.path.join(self.data_root, "images", l)
  33. for l in self.image_paths],
  34. "relative_segmentation_path_": [l.replace(".jpg", ".png")
  35. for l in self.image_paths],
  36. "segmentation_path_": [os.path.join(self.data_root, "annotations",
  37. l.replace(".jpg", ".png"))
  38. for l in self.image_paths],
  39. "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
  40. for l in self.image_paths],
  41. }
  42. size = None if size is not None and size<=0 else size
  43. self.size = size
  44. if crop_size is None:
  45. self.crop_size = size if size is not None else None
  46. else:
  47. self.crop_size = crop_size
  48. if self.size is not None:
  49. self.interpolation = interpolation
  50. self.interpolation = {
  51. "nearest": cv2.INTER_NEAREST,
  52. "bilinear": cv2.INTER_LINEAR,
  53. "bicubic": cv2.INTER_CUBIC,
  54. "area": cv2.INTER_AREA,
  55. "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
  56. self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
  57. interpolation=self.interpolation)
  58. self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
  59. interpolation=cv2.INTER_NEAREST)
  60. if crop_size is not None:
  61. self.center_crop = not random_crop
  62. if self.center_crop:
  63. self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
  64. else:
  65. self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
  66. self.preprocessor = self.cropper
  67. def __len__(self):
  68. return self._length
  69. def __getitem__(self, i):
  70. example = dict((k, self.labels[k][i]) for k in self.labels)
  71. image = Image.open(example["file_path_"])
  72. if not image.mode == "RGB":
  73. image = image.convert("RGB")
  74. image = np.array(image).astype(np.uint8)
  75. if self.size is not None:
  76. image = self.image_rescaler(image=image)["image"]
  77. segmentation = Image.open(example["segmentation_path_"])
  78. segmentation = np.array(segmentation).astype(np.uint8)
  79. if self.size is not None:
  80. segmentation = self.segmentation_rescaler(image=segmentation)["image"]
  81. if self.size is not None:
  82. processed = self.preprocessor(image=image, mask=segmentation)
  83. else:
  84. processed = {"image": image, "mask": segmentation}
  85. example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
  86. segmentation = processed["mask"]
  87. onehot = np.eye(self.n_labels)[segmentation]
  88. example["segmentation"] = onehot
  89. return example
  90. class ADE20kTrain(ADE20kBase):
  91. # default to random_crop=True
  92. def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
  93. super().__init__(config=config, size=size, random_crop=random_crop,
  94. interpolation=interpolation, crop_size=crop_size)
  95. def get_split(self):
  96. return "train"
  97. class ADE20kValidation(ADE20kBase):
  98. def get_split(self):
  99. return "validation"
  100. if __name__ == "__main__":
  101. dset = ADE20kValidation()
  102. ex = dset[0]
  103. for k in ["image", "scene_category", "segmentation"]:
  104. print(type(ex[k]))
  105. try:
  106. print(ex[k].shape)
  107. except:
  108. print(ex[k])