# coco.py

import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset

from taming.data.sflckr import SegmentationBase  # for examples included in repo


class Examples(SegmentationBase):
    def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv="data/coco_examples.txt",
                         data_root="data/coco_images",
                         segmentation_root="data/coco_segmentations",
                         size=size, random_crop=random_crop,
                         interpolation=interpolation,
                         n_labels=183, shift_segmentation=True)
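
# Note: n_labels=183 corresponds to the 182 CocoStuff stuff+thing classes plus
# one "unlabeled" class; shift_segmentation=True moves "unlabeled" to index 0,
# the same convention CocoBase.preprocess_image applies below in one-hot mode.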


class CocoBase(Dataset):
    """needed for (image, caption, segmentation) pairs"""
    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
                 crop_size=None, force_no_crop=False, given_files=None):
        self.split = self.get_split()
        self.size = size
        if crop_size is None:
            self.crop_size = size
        else:
            self.crop_size = crop_size

        self.onehot = onehot_segmentation  # return segmentation as rgb or one hot
        self.stuffthing = use_stuffthing  # include thing in segmentation
        if self.onehot and not self.stuffthing:
            raise NotImplementedError("One hot mode is only supported for the "
                                      "stuffthings version because labels are stored "
                                      "a bit differently.")

        data_json = datajson
        with open(data_json) as json_file:
            self.json_data = json.load(json_file)
            self.img_id_to_captions = dict()
            self.img_id_to_filepath = dict()
            self.img_id_to_segmentation_filepath = dict()

        assert data_json.split("/")[-1] in ["captions_train2017.json",
                                            "captions_val2017.json"]
        if self.stuffthing:
            self.segmentation_prefix = (
                "data/cocostuffthings/val2017" if
                data_json.endswith("captions_val2017.json") else
                "data/cocostuffthings/train2017")
        else:
            self.segmentation_prefix = (
                "data/coco/annotations/stuff_val2017_pixelmaps" if
                data_json.endswith("captions_val2017.json") else
                "data/coco/annotations/stuff_train2017_pixelmaps")

        imagedirs = self.json_data["images"]
        self.labels = {"image_ids": list()}
        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
            self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
            self.img_id_to_captions[imgdir["id"]] = list()
            pngfilename = imgdir["file_name"].replace("jpg", "png")
            self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
                self.segmentation_prefix, pngfilename)
            if given_files is not None:
                if pngfilename in given_files:
                    self.labels["image_ids"].append(imgdir["id"])
            else:
                self.labels["image_ids"].append(imgdir["id"])

        capdirs = self.json_data["annotations"]
        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # there are on average 5 captions per image
            self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        if self.split == "validation":
            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        else:
            self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
        # additional_targets applies the same rescale/crop to the segmentation map
        self.preprocessor = albumentations.Compose(
            [self.rescaler, self.cropper],
            additional_targets={"segmentation": "image"})
        if force_no_crop:
            self.rescaler = albumentations.Resize(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose(
                [self.rescaler],
                additional_targets={"segmentation": "image"})

    def __len__(self):
        return len(self.labels["image_ids"])

    def preprocess_image(self, image_path, segmentation_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)

        segmentation = Image.open(segmentation_path)
        if not self.onehot and not segmentation.mode == "RGB":
            segmentation = segmentation.convert("RGB")
        segmentation = np.array(segmentation).astype(np.uint8)

        if self.onehot:
            assert self.stuffthing
            # stored in caffe format: unlabeled==255. stuff and thing from
            # 0-181. to be compatible with the labels in
            # https://github.com/nightrome/cocostuff/blob/master/labels.txt
            # we shift stuffthing one to the right and put unlabeled in zero.
            # as long as segmentation is uint8, the shift handles the latter
            # too, since 255 + 1 wraps around to 0
            assert segmentation.dtype == np.uint8
            segmentation = segmentation + 1

        processed = self.preprocessor(image=image, segmentation=segmentation)
        image, segmentation = processed["image"], processed["segmentation"]
        image = (image / 127.5 - 1.0).astype(np.float32)

        if self.onehot:
            assert segmentation.dtype == np.uint8
            # make it one hot: (H, W) uint8 labels -> (H, W, n_labels) int
            n_labels = 183
            flatseg = np.ravel(segmentation)
            onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
            onehot[np.arange(flatseg.size), flatseg] = True
            onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
            segmentation = onehot
        else:
            segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
        return image, segmentation

    def __getitem__(self, i):
        img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
        seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
        image, segmentation = self.preprocess_image(img_path, seg_path)
        captions = self.img_id_to_captions[self.labels["image_ids"][i]]
        # randomly draw one of all available captions per image
        caption = captions[np.random.randint(0, len(captions))]
        example = {"image": image,
                   "caption": [str(caption[0])],
                   "segmentation": segmentation,
                   "img_path": img_path,
                   "seg_path": seg_path,
                   "filename_": img_path.split(os.sep)[-1]
                   }
        return example


class CocoImagesAndCaptionsTrain(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
        super().__init__(size=size,
                         dataroot="data/coco/train2017",
                         datajson="data/coco/annotations/captions_train2017.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)

    def get_split(self):
        return "train"


class CocoImagesAndCaptionsValidation(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
                 given_files=None):
        super().__init__(size=size,
                         dataroot="data/coco/val2017",
                         datajson="data/coco/annotations/captions_val2017.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
                         given_files=given_files)

    def get_split(self):
        return "validation"
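

# Minimal usage sketch: assumes the COCO 2017 val images, captions_val2017.json
# and the stuff segmentation pixelmaps have been downloaded to the data/coco/
# paths hardcoded above.
if __name__ == "__main__":
    dataset = CocoImagesAndCaptionsValidation(size=256)
    print(f"{len(dataset)} validation examples")
    ex = dataset[0]
    print(ex["caption"])             # list with one randomly drawn caption
    print(ex["image"].shape)         # (256, 256, 3), float32 in [-1, 1]
    print(ex["segmentation"].shape)  # (256, 256, 3) rgb map in the default mode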