utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import importlib
  2. from typing import List, Any, Tuple, Optional
  3. from taming.data.helper_types import BoundingBox, Annotation
  4. # source: seaborn, color palette tab10
  5. COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
  6. (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
  7. BLACK = (0, 0, 0)
  8. GRAY_75 = (63, 63, 63)
  9. GRAY_50 = (127, 127, 127)
  10. GRAY_25 = (191, 191, 191)
  11. WHITE = (255, 255, 255)
  12. FULL_CROP = (0., 0., 1., 1.)
  13. def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
  14. """
  15. Give intersection area of two rectangles.
  16. @param rectangle1: (x0, y0, w, h) of first rectangle
  17. @param rectangle2: (x0, y0, w, h) of second rectangle
  18. """
  19. rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
  20. rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
  21. x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
  22. y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
  23. return x_overlap * y_overlap
  24. def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
  25. return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
  26. def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
  27. bbox = relative_bbox
  28. bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
  29. return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
  30. def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
  31. return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
  32. def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
  33. List[Annotation]:
  34. def clamp(x: float):
  35. return max(min(x, 1.), 0.)
  36. def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
  37. x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
  38. y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
  39. w = min(bbox[2] / crop_coordinates[2], 1 - x0)
  40. h = min(bbox[3] / crop_coordinates[3], 1 - y0)
  41. if flip:
  42. x0 = 1 - (x0 + w)
  43. return x0, y0, w, h
  44. return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
  45. def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
  46. return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
  47. def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
  48. sl = slice(1) if short else slice(None)
  49. string = ''
  50. if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
  51. return string
  52. if annotation.is_group_of:
  53. string += 'group'[sl] + ','
  54. if annotation.is_occluded:
  55. string += 'occluded'[sl] + ','
  56. if annotation.is_depiction:
  57. string += 'depiction'[sl] + ','
  58. if annotation.is_inside:
  59. string += 'inside'[sl]
  60. return '(' + string.strip(",") + ')'
  61. def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
  62. if font_size is None:
  63. font_size = 10
  64. if max(figure_size) >= 256:
  65. font_size = 12
  66. if max(figure_size) >= 512:
  67. font_size = 15
  68. return font_size
  69. def get_circle_size(figure_size: Tuple[int, int]) -> int:
  70. circle_size = 2
  71. if max(figure_size) >= 256:
  72. circle_size = 3
  73. if max(figure_size) >= 512:
  74. circle_size = 4
  75. return circle_size
  76. def load_object_from_string(object_string: str) -> Any:
  77. """
  78. Source: https://stackoverflow.com/a/10773699
  79. """
  80. module_name, class_name = object_string.rsplit(".", 1)
  81. return getattr(importlib.import_module(module_name), class_name)