# objects_center_points.py

import math
import random
import warnings
from itertools import cycle
from typing import Callable, List, Optional, Tuple

from more_itertools.recipes import grouper
from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from torch import LongTensor, Tensor

from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
    additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
    absolute_bbox, rescale_annotations
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.image_transforms import convert_pil_to_tensor


class ObjectsCenterPointsConditionalBuilder:
    def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
                 use_group_parameter: bool, use_additional_parameters: bool):
        self.no_object_classes = no_object_classes
        self.no_max_objects = no_max_objects
        self.no_tokens = no_tokens
        self.encode_crop = encode_crop
        # Coordinates are quantized onto a no_sections x no_sections grid, so every
        # grid cell maps to exactly one token in [0, no_tokens).
        self.no_sections = int(math.sqrt(self.no_tokens))
        self.use_group_parameter = use_group_parameter
        self.use_additional_parameters = use_additional_parameters

    @property
    def none(self) -> int:
        # The last token id is reserved as the "no object" padding token.
        return self.no_tokens - 1

    @property
    def object_descriptor_length(self) -> int:
        # Each object is encoded as (representation token, center-point token).
        return 2

    @property
    def embedding_dim(self) -> int:
        extra_length = 2 if self.encode_crop else 0
        return self.no_max_objects * self.object_descriptor_length + extra_length
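
    # Worked example (illustrative numbers, not from the original source): with
    # no_max_objects=30 and encode_crop=True, embedding_dim = 30 * 2 + 2 = 62.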

    def tokenize_coordinates(self, x: float, y: float) -> int:
        """
        Express 2d coordinates with one number.
        Example: assume self.no_tokens = 16, then no_sections = 4:
        0 0 0 0
        0 0 # 0
        0 0 0 0
        0 0 0 x
        Then the # position corresponds to token 6, the x position to token 15.
        @param x: float in [0, 1]
        @param y: float in [0, 1]
        @return: discrete tokenized coordinate
        """
        x_discrete = int(round(x * (self.no_sections - 1)))
        y_discrete = int(round(y * (self.no_sections - 1)))
        return y_discrete * self.no_sections + x_discrete
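
    # Round-trip sketch (assuming no_tokens=16 as in the docstring above):
    # tokenize_coordinates(2/3, 1/3) -> 1 * 4 + 2 = 6, and coordinates_from_token(6)
    # recovers (2/3, 1/3) exactly, since grid points survive quantization unchanged.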

    def coordinates_from_token(self, token: int) -> Tuple[float, float]:
        x = token % self.no_sections
        y = token // self.no_sections
        return x / (self.no_sections - 1), y / (self.no_sections - 1)

    def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
        x0, y0 = self.coordinates_from_token(token1)
        x1, y1 = self.coordinates_from_token(token2)
        return x0, y0, x1 - x0, y1 - y0

    def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
        # A bbox is (x, y, width, height); encode its top-left and bottom-right corners.
        return self.tokenize_coordinates(bbox[0], bbox[1]), \
            self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])

    def inverse_build(self, conditional: LongTensor) \
            -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
        conditional_list = conditional.tolist()
        crop_coordinates = None
        if self.encode_crop:
            # The last two tokens hold the crop bbox corners.
            crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
            conditional_list = conditional_list[:-2]
        table_of_content = grouper(conditional_list, self.object_descriptor_length)
        assert conditional.shape[0] == self.embedding_dim
        return [
            (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
            for object_tuple in table_of_content if object_tuple[0] != self.none
        ], crop_coordinates
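
    # Conditional layout, as produced by build() below:
    # [repr_0, center_0, ..., repr_{N-1}, center_{N-1}] followed, when encode_crop
    # is set, by two trailing tokens for the crop bbox corners.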

    def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
             line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
        plot = pil_image.new('RGB', figure_size, WHITE)
        draw = pil_img_draw.Draw(plot)
        circle_size = get_circle_size(figure_size)
        # Note: the font path is hardcoded and assumes Lato is installed on the system.
        font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
                                  size=get_plot_font_size(font_size, figure_size))
        width, height = plot.size
        description, crop_coordinates = self.inverse_build(conditional)
        for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
            x_abs, y_abs = x * width, y * height
            ann = self.representation_to_annotation(representation)
            label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
            ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
            draw.ellipse(ellipse_bbox, fill=color, width=0)
            draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
        if crop_coordinates is not None:
            draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
        # Scale pixel values from [0, 255] to [-1, 1].
        return convert_pil_to_tensor(plot) / 127.5 - 1.

    def object_representation(self, annotation: Annotation) -> int:
        # Pack the boolean flags into a bit field and use it to offset the category
        # number, so each (category, flags) combination gets a distinct token.
        modifier = 0
        if self.use_group_parameter:
            modifier |= 1 * (annotation.is_group_of is True)
        if self.use_additional_parameters:
            modifier |= 2 * (annotation.is_occluded is True)
            modifier |= 4 * (annotation.is_depiction is True)
            modifier |= 8 * (annotation.is_inside is True)
        return annotation.category_no + self.no_object_classes * modifier
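
    # Worked example (illustrative numbers): with no_object_classes=200,
    # category_no=5 and is_group_of=True yield modifier=1, so the representation
    # is 5 + 200 * 1 = 205; representation_to_annotation inverts this via % and //.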

    def representation_to_annotation(self, representation: int) -> Annotation:
        category_no = representation % self.no_object_classes
        modifier = representation // self.no_object_classes
        # noinspection PyTypeChecker
        return Annotation(
            area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
            category_no=category_no,
            is_group_of=bool((modifier & 1) * self.use_group_parameter),
            is_occluded=bool((modifier & 2) * self.use_additional_parameters),
            is_depiction=bool((modifier & 4) * self.use_additional_parameters),
            is_inside=bool((modifier & 8) * self.use_additional_parameters)
        )

    def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
        return list(self.token_pair_from_bbox(crop_coordinates))

    def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
        # Each object becomes (representation token, tokenized bbox center),
        # padded with (none, none) tuples up to no_max_objects.
        object_tuples = [
            (self.object_representation(a),
             self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
            for a in annotations
        ]
        empty_tuple = (self.none, self.none)
        object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
        return object_tuples

    def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
            -> LongTensor:
        if len(annotations) == 0:
            warnings.warn('Did not receive any annotations.')
        if len(annotations) > self.no_max_objects:
            warnings.warn('Received more annotations than allowed.')
            annotations = annotations[:self.no_max_objects]
        if not crop_coordinates:
            crop_coordinates = FULL_CROP
        random.shuffle(annotations)
        annotations = filter_annotations(annotations, crop_coordinates)
        if self.encode_crop:
            # Crop is encoded explicitly, so annotations are only flipped, not rescaled.
            annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
            if horizontal_flip:
                crop_coordinates = horizontally_flip_bbox(crop_coordinates)
            extra = self._crop_encoder(crop_coordinates)
        else:
            # Crop is applied implicitly by rescaling annotations into the crop window.
            annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
            extra = []
        object_tuples = self._make_object_descriptors(annotations)
        flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
        assert len(flattened) == self.embedding_dim
        assert all(0 <= value < self.no_tokens for value in flattened)
        return LongTensor(flattened)
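

# A minimal usage sketch (not part of the original file). The parameter values and
# the Annotation fields below are illustrative; Annotation is assumed to accept the
# same keyword arguments that representation_to_annotation passes above.
if __name__ == '__main__':
    builder = ObjectsCenterPointsConditionalBuilder(
        no_object_classes=200, no_max_objects=30, no_tokens=1024, encode_crop=False,
        use_group_parameter=True, use_additional_parameters=False)
    annotation = Annotation(
        area=None, image_id=None, id=None, source=None, confidence=None,
        category_id=None, category_no=5, bbox=(0.25, 0.25, 0.5, 0.5),
        is_group_of=False, is_occluded=False, is_depiction=False, is_inside=False)
    conditional = builder.build([annotation])
    print(conditional.shape)  # torch.Size([60]), i.e. embedding_dim = 30 * 2
    # inverse_build recovers (representation, center) pairs for non-padding slots;
    # the center of the bbox above is (0.5, 0.5), up to grid quantization.
    print(builder.inverse_build(conditional))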