objects_bbox.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from itertools import cycle
  2. from typing import List, Tuple, Callable, Optional
  3. from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
  4. from more_itertools.recipes import grouper
  5. from taming.data.image_transforms import convert_pil_to_tensor
  6. from torch import LongTensor, Tensor
  7. from taming.data.helper_types import BoundingBox, Annotation
  8. from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
  9. from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
  10. pad_list, get_plot_font_size, absolute_bbox
  11. class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
  12. @property
  13. def object_descriptor_length(self) -> int:
  14. return 3
  15. def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
  16. object_triples = [
  17. (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
  18. for ann in annotations
  19. ]
  20. empty_triple = (self.none, self.none, self.none)
  21. object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
  22. return object_triples
  23. def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
  24. conditional_list = conditional.tolist()
  25. crop_coordinates = None
  26. if self.encode_crop:
  27. crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
  28. conditional_list = conditional_list[:-2]
  29. object_triples = grouper(conditional_list, 3)
  30. assert conditional.shape[0] == self.embedding_dim
  31. return [
  32. (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
  33. for object_triple in object_triples if object_triple[0] != self.none
  34. ], crop_coordinates
  35. def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
  36. line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
  37. plot = pil_image.new('RGB', figure_size, WHITE)
  38. draw = pil_img_draw.Draw(plot)
  39. font = ImageFont.truetype(
  40. "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
  41. size=get_plot_font_size(font_size, figure_size)
  42. )
  43. width, height = plot.size
  44. description, crop_coordinates = self.inverse_build(conditional)
  45. for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
  46. annotation = self.representation_to_annotation(representation)
  47. class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
  48. bbox = absolute_bbox(bbox, width, height)
  49. draw.rectangle(bbox, outline=color, width=line_width)
  50. draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
  51. if crop_coordinates is not None:
  52. draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
  53. return convert_pil_to_tensor(plot) / 127.5 - 1.