annotated_objects_coco.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import json
  2. from itertools import chain
  3. from pathlib import Path
  4. from typing import Iterable, Dict, List, Callable, Any
  5. from collections import defaultdict
  6. from tqdm import tqdm
  7. from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
  8. from taming.data.helper_types import Annotation, ImageDescription, Category
  9. COCO_PATH_STRUCTURE = {
  10. 'train': {
  11. 'top_level': '',
  12. 'instances_annotations': 'annotations/instances_train2017.json',
  13. 'stuff_annotations': 'annotations/stuff_train2017.json',
  14. 'files': 'train2017'
  15. },
  16. 'validation': {
  17. 'top_level': '',
  18. 'instances_annotations': 'annotations/instances_val2017.json',
  19. 'stuff_annotations': 'annotations/stuff_val2017.json',
  20. 'files': 'val2017'
  21. }
  22. }
  23. def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
  24. return {
  25. str(img['id']): ImageDescription(
  26. id=img['id'],
  27. license=img.get('license'),
  28. file_name=img['file_name'],
  29. coco_url=img['coco_url'],
  30. original_size=(img['width'], img['height']),
  31. date_captured=img.get('date_captured'),
  32. flickr_url=img.get('flickr_url')
  33. )
  34. for img in description_json
  35. }
  36. def load_categories(category_json: Iterable) -> Dict[str, Category]:
  37. return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
  38. for cat in category_json if cat['name'] != 'other'}
  39. def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
  40. category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
  41. annotations = defaultdict(list)
  42. total = sum(len(a) for a in annotations_json)
  43. for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
  44. image_id = str(ann['image_id'])
  45. if image_id not in image_descriptions:
  46. raise ValueError(f'image_id [{image_id}] has no image description.')
  47. category_id = ann['category_id']
  48. try:
  49. category_no = category_no_for_id(str(category_id))
  50. except KeyError:
  51. continue
  52. width, height = image_descriptions[image_id].original_size
  53. bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
  54. annotations[image_id].append(
  55. Annotation(
  56. id=ann['id'],
  57. area=bbox[2]*bbox[3], # use bbox area
  58. is_group_of=ann['iscrowd'],
  59. image_id=ann['image_id'],
  60. bbox=bbox,
  61. category_id=str(category_id),
  62. category_no=category_no
  63. )
  64. )
  65. return dict(annotations)
  66. class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
  67. def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
  68. """
  69. @param data_path: is the path to the following folder structure:
  70. coco/
  71. ├── annotations
  72. │ ├── instances_train2017.json
  73. │ ├── instances_val2017.json
  74. │ ├── stuff_train2017.json
  75. │ └── stuff_val2017.json
  76. ├── train2017
  77. │ ├── 000000000009.jpg
  78. │ ├── 000000000025.jpg
  79. │ └── ...
  80. ├── val2017
  81. │ ├── 000000000139.jpg
  82. │ ├── 000000000285.jpg
  83. │ └── ...
  84. @param: split: one of 'train' or 'validation'
  85. @param: desired image size (give square images)
  86. """
  87. super().__init__(**kwargs)
  88. self.use_things = use_things
  89. self.use_stuff = use_stuff
  90. with open(self.paths['instances_annotations']) as f:
  91. inst_data_json = json.load(f)
  92. with open(self.paths['stuff_annotations']) as f:
  93. stuff_data_json = json.load(f)
  94. category_jsons = []
  95. annotation_jsons = []
  96. if self.use_things:
  97. category_jsons.append(inst_data_json['categories'])
  98. annotation_jsons.append(inst_data_json['annotations'])
  99. if self.use_stuff:
  100. category_jsons.append(stuff_data_json['categories'])
  101. annotation_jsons.append(stuff_data_json['annotations'])
  102. self.categories = load_categories(chain(*category_jsons))
  103. self.filter_categories()
  104. self.setup_category_id_and_number()
  105. self.image_descriptions = load_image_descriptions(inst_data_json['images'])
  106. annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
  107. self.annotations = self.filter_object_number(annotations, self.min_object_area,
  108. self.min_objects_per_image, self.max_objects_per_image)
  109. self.image_ids = list(self.annotations.keys())
  110. self.clean_up_annotations_and_image_descriptions()
  111. def get_path_structure(self) -> Dict[str, str]:
  112. if self.split not in COCO_PATH_STRUCTURE:
  113. raise ValueError(f'Split [{self.split} does not exist for COCO data.]')
  114. return COCO_PATH_STRUCTURE[self.split]
  115. def get_image_path(self, image_id: str) -> Path:
  116. return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
  117. def get_image_description(self, image_id: str) -> Dict[str, Any]:
  118. # noinspection PyProtectedMember
  119. return self.image_descriptions[image_id]._asdict()