123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- import math
- import random
- import warnings
- from itertools import cycle
- from typing import List, Optional, Tuple, Callable
- from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
- from more_itertools.recipes import grouper
- from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
- additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
- absolute_bbox, rescale_annotations
- from taming.data.helper_types import BoundingBox, Annotation
- from taming.data.image_transforms import convert_pil_to_tensor
- from torch import LongTensor, Tensor
- class ObjectsCenterPointsConditionalBuilder:
- def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
- use_group_parameter: bool, use_additional_parameters: bool):
- self.no_object_classes = no_object_classes
- self.no_max_objects = no_max_objects
- self.no_tokens = no_tokens
- self.encode_crop = encode_crop
- self.no_sections = int(math.sqrt(self.no_tokens))
- self.use_group_parameter = use_group_parameter
- self.use_additional_parameters = use_additional_parameters
- @property
- def none(self) -> int:
- return self.no_tokens - 1
- @property
- def object_descriptor_length(self) -> int:
- return 2
- @property
- def embedding_dim(self) -> int:
- extra_length = 2 if self.encode_crop else 0
- return self.no_max_objects * self.object_descriptor_length + extra_length
- def tokenize_coordinates(self, x: float, y: float) -> int:
- """
- Express 2d coordinates with one number.
- Example: assume self.no_tokens = 16, then no_sections = 4:
- 0 0 0 0
- 0 0 # 0
- 0 0 0 0
- 0 0 0 x
- Then the # position corresponds to token 6, the x position to token 15.
- @param x: float in [0, 1]
- @param y: float in [0, 1]
- @return: discrete tokenized coordinate
- """
- x_discrete = int(round(x * (self.no_sections - 1)))
- y_discrete = int(round(y * (self.no_sections - 1)))
- return y_discrete * self.no_sections + x_discrete
- def coordinates_from_token(self, token: int) -> (float, float):
- x = token % self.no_sections
- y = token // self.no_sections
- return x / (self.no_sections - 1), y / (self.no_sections - 1)
- def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
- x0, y0 = self.coordinates_from_token(token1)
- x1, y1 = self.coordinates_from_token(token2)
- return x0, y0, x1 - x0, y1 - y0
- def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
- return self.tokenize_coordinates(bbox[0], bbox[1]), \
- self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
- def inverse_build(self, conditional: LongTensor) \
- -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
- conditional_list = conditional.tolist()
- crop_coordinates = None
- if self.encode_crop:
- crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
- conditional_list = conditional_list[:-2]
- table_of_content = grouper(conditional_list, self.object_descriptor_length)
- assert conditional.shape[0] == self.embedding_dim
- return [
- (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
- for object_tuple in table_of_content if object_tuple[0] != self.none
- ], crop_coordinates
- def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
- line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
- plot = pil_image.new('RGB', figure_size, WHITE)
- draw = pil_img_draw.Draw(plot)
- circle_size = get_circle_size(figure_size)
- font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
- size=get_plot_font_size(font_size, figure_size))
- width, height = plot.size
- description, crop_coordinates = self.inverse_build(conditional)
- for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
- x_abs, y_abs = x * width, y * height
- ann = self.representation_to_annotation(representation)
- label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
- ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
- draw.ellipse(ellipse_bbox, fill=color, width=0)
- draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
- if crop_coordinates is not None:
- draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
- return convert_pil_to_tensor(plot) / 127.5 - 1.
- def object_representation(self, annotation: Annotation) -> int:
- modifier = 0
- if self.use_group_parameter:
- modifier |= 1 * (annotation.is_group_of is True)
- if self.use_additional_parameters:
- modifier |= 2 * (annotation.is_occluded is True)
- modifier |= 4 * (annotation.is_depiction is True)
- modifier |= 8 * (annotation.is_inside is True)
- return annotation.category_no + self.no_object_classes * modifier
- def representation_to_annotation(self, representation: int) -> Annotation:
- category_no = representation % self.no_object_classes
- modifier = representation // self.no_object_classes
- # noinspection PyTypeChecker
- return Annotation(
- area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
- category_no=category_no,
- is_group_of=bool((modifier & 1) * self.use_group_parameter),
- is_occluded=bool((modifier & 2) * self.use_additional_parameters),
- is_depiction=bool((modifier & 4) * self.use_additional_parameters),
- is_inside=bool((modifier & 8) * self.use_additional_parameters)
- )
- def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
- return list(self.token_pair_from_bbox(crop_coordinates))
- def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
- object_tuples = [
- (self.object_representation(a),
- self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
- for a in annotations
- ]
- empty_tuple = (self.none, self.none)
- object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
- return object_tuples
- def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
- -> LongTensor:
- if len(annotations) == 0:
- warnings.warn('Did not receive any annotations.')
- if len(annotations) > self.no_max_objects:
- warnings.warn('Received more annotations than allowed.')
- annotations = annotations[:self.no_max_objects]
- if not crop_coordinates:
- crop_coordinates = FULL_CROP
- random.shuffle(annotations)
- annotations = filter_annotations(annotations, crop_coordinates)
- if self.encode_crop:
- annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
- if horizontal_flip:
- crop_coordinates = horizontally_flip_bbox(crop_coordinates)
- extra = self._crop_encoder(crop_coordinates)
- else:
- annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
- extra = []
- object_tuples = self._make_object_descriptors(annotations)
- flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
- assert len(flattened) == self.embedding_dim
- assert all(0 <= value < self.no_tokens for value in flattened)
- return LongTensor(flattened)
|