# annotated_objects_dataset.py
from pathlib import Path
from typing import Optional, List, Callable, Dict, Any, Union
import warnings

import PIL.Image as pil_image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms

from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
    Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
  14. class AnnotatedObjectsDataset(Dataset):
  15. def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
  16. min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
  17. crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
  18. encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
  19. no_object_classes: Optional[int] = None):
  20. self.data_path = data_path
  21. self.split = split
  22. self.keys = keys
  23. self.target_image_size = target_image_size
  24. self.min_object_area = min_object_area
  25. self.min_objects_per_image = min_objects_per_image
  26. self.max_objects_per_image = max_objects_per_image
  27. self.crop_method = crop_method
  28. self.random_flip = random_flip
  29. self.no_tokens = no_tokens
  30. self.use_group_parameter = use_group_parameter
  31. self.encode_crop = encode_crop
  32. self.annotations = None
  33. self.image_descriptions = None
  34. self.categories = None
  35. self.category_ids = None
  36. self.category_number = None
  37. self.image_ids = None
  38. self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
  39. self.paths = self.build_paths(self.data_path)
  40. self._conditional_builders = None
  41. self.category_allow_list = None
  42. if category_allow_list_target:
  43. allow_list = load_object_from_string(category_allow_list_target)
  44. self.category_allow_list = {name for name, _ in allow_list}
  45. self.category_mapping = {}
  46. if category_mapping_target:
  47. self.category_mapping = load_object_from_string(category_mapping_target)
  48. self.no_object_classes = no_object_classes
  49. def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
  50. top_level = Path(top_level)
  51. sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
  52. for path in sub_paths.values():
  53. if not path.exists():
  54. raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
  55. return sub_paths
  56. @staticmethod
  57. def load_image_from_disk(path: Path) -> Image:
  58. return pil_image.open(path).convert('RGB')
  59. @staticmethod
  60. def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
  61. transform_functions = []
  62. if crop_method == 'none':
  63. transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
  64. elif crop_method == 'center':
  65. transform_functions.extend([
  66. transforms.Resize(target_image_size),
  67. CenterCropReturnCoordinates(target_image_size)
  68. ])
  69. elif crop_method == 'random-1d':
  70. transform_functions.extend([
  71. transforms.Resize(target_image_size),
  72. RandomCrop1dReturnCoordinates(target_image_size)
  73. ])
  74. elif crop_method == 'random-2d':
  75. transform_functions.extend([
  76. Random2dCropReturnCoordinates(target_image_size),
  77. transforms.Resize(target_image_size)
  78. ])
  79. elif crop_method is None:
  80. return None
  81. else:
  82. raise ValueError(f'Received invalid crop method [{crop_method}].')
  83. if random_flip:
  84. transform_functions.append(RandomHorizontalFlipReturn())
  85. transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
  86. return transform_functions
  87. def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
  88. crop_bbox = None
  89. flipped = None
  90. for t in self.transform_functions:
  91. if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
  92. crop_bbox, x = t(x)
  93. elif isinstance(t, RandomHorizontalFlipReturn):
  94. flipped, x = t(x)
  95. else:
  96. x = t(x)
  97. return crop_bbox, flipped, x
  98. @property
  99. def no_classes(self) -> int:
  100. return self.no_object_classes if self.no_object_classes else len(self.categories)
  101. @property
  102. def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
  103. # cannot set this up in init because no_classes is only known after loading data in init of superclass
  104. if self._conditional_builders is None:
  105. self._conditional_builders = {
  106. 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
  107. self.no_classes,
  108. self.max_objects_per_image,
  109. self.no_tokens,
  110. self.encode_crop,
  111. self.use_group_parameter,
  112. getattr(self, 'use_additional_parameters', False)
  113. ),
  114. 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
  115. self.no_classes,
  116. self.max_objects_per_image,
  117. self.no_tokens,
  118. self.encode_crop,
  119. self.use_group_parameter,
  120. getattr(self, 'use_additional_parameters', False)
  121. )
  122. }
  123. return self._conditional_builders
  124. def filter_categories(self) -> None:
  125. if self.category_allow_list:
  126. self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
  127. if self.category_mapping:
  128. self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
  129. def setup_category_id_and_number(self) -> None:
  130. self.category_ids = list(self.categories.keys())
  131. self.category_ids.sort()
  132. if '/m/01s55n' in self.category_ids:
  133. self.category_ids.remove('/m/01s55n')
  134. self.category_ids.append('/m/01s55n')
  135. self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
  136. if self.category_allow_list is not None and self.category_mapping is None \
  137. and len(self.category_ids) != len(self.category_allow_list):
  138. warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
  139. 'Make sure all names in category_allow_list exist.')
  140. def clean_up_annotations_and_image_descriptions(self) -> None:
  141. image_id_set = set(self.image_ids)
  142. self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
  143. self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
  144. @staticmethod
  145. def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
  146. min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
  147. filtered = {}
  148. for image_id, annotations in all_annotations.items():
  149. annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
  150. if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
  151. filtered[image_id] = annotations_with_min_area
  152. return filtered
  153. def __len__(self):
  154. return len(self.image_ids)
  155. def __getitem__(self, n: int) -> Dict[str, Any]:
  156. image_id = self.get_image_id(n)
  157. sample = self.get_image_description(image_id)
  158. sample['annotations'] = self.get_annotation(image_id)
  159. if 'image' in self.keys:
  160. sample['image_path'] = str(self.get_image_path(image_id))
  161. sample['image'] = self.load_image_from_disk(sample['image_path'])
  162. sample['image'] = convert_pil_to_tensor(sample['image'])
  163. sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
  164. sample['image'] = sample['image'].permute(1, 2, 0)
  165. for conditional, builder in self.conditional_builders.items():
  166. if conditional in self.keys:
  167. sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
  168. if self.keys:
  169. # only return specified keys
  170. sample = {key: sample[key] for key in self.keys}
  171. return sample
  172. def get_image_id(self, no: int) -> str:
  173. return self.image_ids[no]
  174. def get_annotation(self, image_id: str) -> str:
  175. return self.annotations[image_id]
  176. def get_textual_label_for_category_id(self, category_id: str) -> str:
  177. return self.categories[category_id].name
  178. def get_textual_label_for_category_no(self, category_no: int) -> str:
  179. return self.categories[self.get_category_id(category_no)].name
  180. def get_category_number(self, category_id: str) -> int:
  181. return self.category_number[category_id]
  182. def get_category_id(self, category_no: int) -> str:
  183. return self.category_ids[category_no]
  184. def get_image_description(self, image_id: str) -> Dict[str, Any]:
  185. raise NotImplementedError()
  186. def get_path_structure(self):
  187. raise NotImplementedError
  188. def get_image_path(self, image_id: str) -> Path:
  189. raise NotImplementedError