utils.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import base64
  2. import os
  3. from io import BytesIO
  4. from typing import Optional, Union
  5. from urllib.parse import urlparse
  6. import aiohttp
  7. import requests
  8. from PIL import Image
  9. from aphrodite.multimodal.base import MultiModalDataDict
  10. from aphrodite.version import __version__ as APHRODITE_VERSION
  11. APHRODITE_IMAGE_FETCH_TIMEOUT = int(
  12. os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
  13. def _validate_remote_url(url: str, *, name: str):
  14. parsed_url = urlparse(url)
  15. if parsed_url.scheme not in ["http", "https"]:
  16. raise ValueError(f"Invalid '{name}': A valid '{name}' "
  17. "must have scheme 'http' or 'https'.")
  18. def _get_request_headers():
  19. return {"User-Agent": f"aphrodite/{APHRODITE_VERSION}"}
  20. def _load_image_from_bytes(b: bytes):
  21. image = Image.open(BytesIO(b))
  22. image.load()
  23. return image
  24. def _load_image_from_data_url(image_url: str):
  25. # Only split once and assume the second part is the base64 encoded image
  26. _, image_base64 = image_url.split(",", 1)
  27. return load_image_from_base64(image_base64)
  28. def fetch_image(image_url: str) -> Image.Image:
  29. """Load PIL image from a url or base64 encoded openai GPT4V format"""
  30. if image_url.startswith('http'):
  31. _validate_remote_url(image_url, name="image_url")
  32. headers = _get_request_headers()
  33. with requests.get(url=image_url, headers=headers) as response:
  34. response.raise_for_status()
  35. image_raw = response.content
  36. image = _load_image_from_bytes(image_raw)
  37. elif image_url.startswith('data:image'):
  38. image = _load_image_from_data_url(image_url)
  39. else:
  40. raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
  41. "with either 'data:image' or 'http'.")
  42. return image
  43. class ImageFetchAiohttp:
  44. aiohttp_client: Optional[aiohttp.ClientSession] = None
  45. @classmethod
  46. def get_aiohttp_client(cls) -> aiohttp.ClientSession:
  47. if cls.aiohttp_client is None:
  48. timeout = aiohttp.ClientTimeout(
  49. total=APHRODITE_IMAGE_FETCH_TIMEOUT)
  50. connector = aiohttp.TCPConnector()
  51. cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
  52. connector=connector)
  53. return cls.aiohttp_client
  54. @classmethod
  55. async def fetch_image(cls, image_url: str) -> Image.Image:
  56. """Load PIL image from a url or base64 encoded openai GPT4V format"""
  57. if image_url.startswith('http'):
  58. _validate_remote_url(image_url, name="image_url")
  59. client = cls.get_aiohttp_client()
  60. headers = _get_request_headers()
  61. async with client.get(url=image_url, headers=headers) as response:
  62. response.raise_for_status()
  63. image_raw = await response.read()
  64. image = _load_image_from_bytes(image_raw)
  65. # Only split once and assume the second part is the base64 encoded image
  66. elif image_url.startswith('data:image'):
  67. image = _load_image_from_data_url(image_url)
  68. else:
  69. raise ValueError(
  70. "Invalid 'image_url': A valid 'image_url' must start "
  71. "with either 'data:image' or 'http'.")
  72. return image
  73. async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
  74. image = await ImageFetchAiohttp.fetch_image(image_url)
  75. return {"image": image}
  76. def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
  77. """Encode a pillow image to base64 format."""
  78. buffered = BytesIO()
  79. if format == 'JPEG':
  80. image = image.convert('RGB')
  81. image.save(buffered, format)
  82. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  83. def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
  84. """Load image from base64 format."""
  85. return _load_image_from_bytes(base64.b64decode(image))
  86. def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
  87. """Rescale the dimensions of an image by a constant factor."""
  88. new_width = int(image.width * size_factor)
  89. new_height = int(image.height * size_factor)
  90. return image.resize((new_width, new_height))