2
0

image_transforms.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import random
  2. import warnings
  3. from typing import Union
  4. import torch
  5. from torch import Tensor
  6. from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
  7. from torchvision.transforms.functional import _get_image_size as get_image_size
  8. from taming.data.helper_types import BoundingBox, Image
  9. pil_to_tensor = PILToTensor()
  10. def convert_pil_to_tensor(image: Image) -> Tensor:
  11. with warnings.catch_warnings():
  12. # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
  13. warnings.simplefilter("ignore")
  14. return pil_to_tensor(image)
  15. class RandomCrop1dReturnCoordinates(RandomCrop):
  16. def forward(self, img: Image) -> (BoundingBox, Image):
  17. """
  18. Additionally to cropping, returns the relative coordinates of the crop bounding box.
  19. Args:
  20. img (PIL Image or Tensor): Image to be cropped.
  21. Returns:
  22. Bounding box: x0, y0, w, h
  23. PIL Image or Tensor: Cropped image.
  24. Based on:
  25. torchvision.transforms.RandomCrop, torchvision 1.7.0
  26. """
  27. if self.padding is not None:
  28. img = F.pad(img, self.padding, self.fill, self.padding_mode)
  29. width, height = get_image_size(img)
  30. # pad the width if needed
  31. if self.pad_if_needed and width < self.size[1]:
  32. padding = [self.size[1] - width, 0]
  33. img = F.pad(img, padding, self.fill, self.padding_mode)
  34. # pad the height if needed
  35. if self.pad_if_needed and height < self.size[0]:
  36. padding = [0, self.size[0] - height]
  37. img = F.pad(img, padding, self.fill, self.padding_mode)
  38. i, j, h, w = self.get_params(img, self.size)
  39. bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
  40. return bbox, F.crop(img, i, j, h, w)
  41. class Random2dCropReturnCoordinates(torch.nn.Module):
  42. """
  43. Additionally to cropping, returns the relative coordinates of the crop bounding box.
  44. Args:
  45. img (PIL Image or Tensor): Image to be cropped.
  46. Returns:
  47. Bounding box: x0, y0, w, h
  48. PIL Image or Tensor: Cropped image.
  49. Based on:
  50. torchvision.transforms.RandomCrop, torchvision 1.7.0
  51. """
  52. def __init__(self, min_size: int):
  53. super().__init__()
  54. self.min_size = min_size
  55. def forward(self, img: Image) -> (BoundingBox, Image):
  56. width, height = get_image_size(img)
  57. max_size = min(width, height)
  58. if max_size <= self.min_size:
  59. size = max_size
  60. else:
  61. size = random.randint(self.min_size, max_size)
  62. top = random.randint(0, height - size)
  63. left = random.randint(0, width - size)
  64. bbox = left / width, top / height, size / width, size / height
  65. return bbox, F.crop(img, top, left, size, size)
  66. class CenterCropReturnCoordinates(CenterCrop):
  67. @staticmethod
  68. def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
  69. if width > height:
  70. w = height / width
  71. h = 1.0
  72. x0 = 0.5 - w / 2
  73. y0 = 0.
  74. else:
  75. w = 1.0
  76. h = width / height
  77. x0 = 0.
  78. y0 = 0.5 - h / 2
  79. return x0, y0, w, h
  80. def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
  81. """
  82. Additionally to cropping, returns the relative coordinates of the crop bounding box.
  83. Args:
  84. img (PIL Image or Tensor): Image to be cropped.
  85. Returns:
  86. Bounding box: x0, y0, w, h
  87. PIL Image or Tensor: Cropped image.
  88. Based on:
  89. torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
  90. """
  91. width, height = get_image_size(img)
  92. return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
  93. class RandomHorizontalFlipReturn(RandomHorizontalFlip):
  94. def forward(self, img: Image) -> (bool, Image):
  95. """
  96. Additionally to flipping, returns a boolean whether it was flipped or not.
  97. Args:
  98. img (PIL Image or Tensor): Image to be flipped.
  99. Returns:
  100. flipped: whether the image was flipped or not
  101. PIL Image or Tensor: Randomly flipped image.
  102. Based on:
  103. torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
  104. """
  105. if torch.rand(1) < self.p:
  106. return True, F.hflip(img)
  107. return False, img