sflckr.py 4.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import numpy as np
  3. import cv2
  4. import albumentations
  5. from PIL import Image
  6. from torch.utils.data import Dataset
  7. class SegmentationBase(Dataset):
  8. def __init__(self,
  9. data_csv, data_root, segmentation_root,
  10. size=None, random_crop=False, interpolation="bicubic",
  11. n_labels=182, shift_segmentation=False,
  12. ):
  13. self.n_labels = n_labels
  14. self.shift_segmentation = shift_segmentation
  15. self.data_csv = data_csv
  16. self.data_root = data_root
  17. self.segmentation_root = segmentation_root
  18. with open(self.data_csv, "r") as f:
  19. self.image_paths = f.read().splitlines()
  20. self._length = len(self.image_paths)
  21. self.labels = {
  22. "relative_file_path_": [l for l in self.image_paths],
  23. "file_path_": [os.path.join(self.data_root, l)
  24. for l in self.image_paths],
  25. "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
  26. for l in self.image_paths]
  27. }
  28. size = None if size is not None and size<=0 else size
  29. self.size = size
  30. if self.size is not None:
  31. self.interpolation = interpolation
  32. self.interpolation = {
  33. "nearest": cv2.INTER_NEAREST,
  34. "bilinear": cv2.INTER_LINEAR,
  35. "bicubic": cv2.INTER_CUBIC,
  36. "area": cv2.INTER_AREA,
  37. "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
  38. self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
  39. interpolation=self.interpolation)
  40. self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
  41. interpolation=cv2.INTER_NEAREST)
  42. self.center_crop = not random_crop
  43. if self.center_crop:
  44. self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
  45. else:
  46. self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
  47. self.preprocessor = self.cropper
  48. def __len__(self):
  49. return self._length
  50. def __getitem__(self, i):
  51. example = dict((k, self.labels[k][i]) for k in self.labels)
  52. image = Image.open(example["file_path_"])
  53. if not image.mode == "RGB":
  54. image = image.convert("RGB")
  55. image = np.array(image).astype(np.uint8)
  56. if self.size is not None:
  57. image = self.image_rescaler(image=image)["image"]
  58. segmentation = Image.open(example["segmentation_path_"])
  59. assert segmentation.mode == "L", segmentation.mode
  60. segmentation = np.array(segmentation).astype(np.uint8)
  61. if self.shift_segmentation:
  62. # used to support segmentations containing unlabeled==255 label
  63. segmentation = segmentation+1
  64. if self.size is not None:
  65. segmentation = self.segmentation_rescaler(image=segmentation)["image"]
  66. if self.size is not None:
  67. processed = self.preprocessor(image=image,
  68. mask=segmentation
  69. )
  70. else:
  71. processed = {"image": image,
  72. "mask": segmentation
  73. }
  74. example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
  75. segmentation = processed["mask"]
  76. onehot = np.eye(self.n_labels)[segmentation]
  77. example["segmentation"] = onehot
  78. return example
  79. class Examples(SegmentationBase):
  80. def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
  81. super().__init__(data_csv="data/sflckr_examples.txt",
  82. data_root="data/sflckr_images",
  83. segmentation_root="data/sflckr_segmentations",
  84. size=size, random_crop=random_crop, interpolation=interpolation)