utils.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import base64
  2. import os
  3. from io import BytesIO
  4. from typing import Optional, Union
  5. from urllib.parse import urlparse
  6. import aiohttp
  7. import requests
  8. from PIL import Image
  9. from aphrodite.multimodal.base import MultiModalDataDict
  10. from aphrodite.version import __version__ as APHRODITE_VERSION
  11. APHRODITE_IMAGE_FETCH_TIMEOUT = int(
  12. os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
  13. def _validate_remote_url(url: str, *, name: str):
  14. parsed_url = urlparse(url)
  15. if parsed_url.scheme not in ["http", "https"]:
  16. raise ValueError(f"Invalid '{name}': A valid '{name}' "
  17. "must have scheme 'http' or 'https'.")
  18. def _get_request_headers():
  19. return {"User-Agent": f"aphrodite/{APHRODITE_VERSION}"}
  20. def _load_image_from_bytes(b: bytes):
  21. image = Image.open(BytesIO(b))
  22. image.load()
  23. return image
  24. def _load_image_from_data_url(image_url: str):
  25. # Only split once and assume the second part is the base64 encoded image
  26. _, image_base64 = image_url.split(",", 1)
  27. return load_image_from_base64(image_base64)
  28. def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
  29. """
  30. Load a PIL image from a HTTP or base64 data URL.
  31. By default, the image is converted into RGB format.
  32. """
  33. if image_url.startswith('http'):
  34. _validate_remote_url(image_url, name="image_url")
  35. headers = _get_request_headers()
  36. with requests.get(url=image_url, headers=headers) as response:
  37. response.raise_for_status()
  38. image_raw = response.content
  39. image = _load_image_from_bytes(image_raw)
  40. elif image_url.startswith('data:image'):
  41. image = _load_image_from_data_url(image_url)
  42. else:
  43. raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
  44. "with either 'data:image' or 'http'.")
  45. return image.convert(image_mode)
  46. class ImageFetchAiohttp:
  47. aiohttp_client: Optional[aiohttp.ClientSession] = None
  48. @classmethod
  49. def get_aiohttp_client(cls) -> aiohttp.ClientSession:
  50. if cls.aiohttp_client is None:
  51. timeout = aiohttp.ClientTimeout(
  52. total=APHRODITE_IMAGE_FETCH_TIMEOUT)
  53. connector = aiohttp.TCPConnector()
  54. cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
  55. connector=connector)
  56. return cls.aiohttp_client
  57. @classmethod
  58. async def fetch_image(
  59. cls,
  60. image_url: str,
  61. *,
  62. image_mode: str = "RGB",
  63. ) -> Image.Image:
  64. """
  65. Asynchronously load a PIL image from a HTTP or base64 data URL.
  66. By default, the image is converted into RGB format.
  67. """
  68. if image_url.startswith('http'):
  69. _validate_remote_url(image_url, name="image_url")
  70. client = cls.get_aiohttp_client()
  71. headers = _get_request_headers()
  72. async with client.get(url=image_url, headers=headers) as response:
  73. response.raise_for_status()
  74. image_raw = await response.read()
  75. image = _load_image_from_bytes(image_raw)
  76. # Only split once and assume the second part is the base64 encoded image
  77. elif image_url.startswith('data:image'):
  78. image = _load_image_from_data_url(image_url)
  79. else:
  80. raise ValueError(
  81. "Invalid 'image_url': A valid 'image_url' must start "
  82. "with either 'data:image' or 'http'.")
  83. return image.convert(image_mode)
  84. async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
  85. image = await ImageFetchAiohttp.fetch_image(image_url)
  86. return {"image": image}
  87. def encode_image_base64(
  88. image: Image.Image,
  89. *,
  90. image_mode: str = "RGB",
  91. format: str = "JPEG",
  92. ) -> str:
  93. """
  94. Encode a pillow image to base64 format.
  95. By default, the image is converted into RGB format before being encoded.
  96. """
  97. buffered = BytesIO()
  98. image = image.convert(image_mode)
  99. image.save(buffered, format)
  100. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  101. def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
  102. """Load image from base64 format."""
  103. return _load_image_from_bytes(base64.b64decode(image))
  104. def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
  105. """Rescale the dimensions of an image by a constant factor."""
  106. new_width = int(image.width * size_factor)
  107. new_height = int(image.height * size_factor)
  108. return image.resize((new_width, new_height))